diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index a42272367..9a94c95e8 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -38,6 +38,7 @@ CompileExample("t15_nodes_mocking") CompileExample("t16_global_blackboard") CompileExample("t17_blackboard_backup") CompileExample("t18_waypoints") +CompileExample("t19_polymorphic_ports") CompileExample("ex01_wrap_legacy") CompileExample("ex02_runtime_ports") diff --git a/examples/t19_polymorphic_ports.cpp b/examples/t19_polymorphic_ports.cpp new file mode 100644 index 000000000..9e58232f5 --- /dev/null +++ b/examples/t19_polymorphic_ports.cpp @@ -0,0 +1,162 @@ +#include "behaviortree_cpp/bt_factory.h" + +using namespace BT; + +/* This tutorial shows how to use polymorphic ports. + * + * When nodes produce and consume shared_ptr via ports, + * you may want a node that outputs shared_ptr to feed + * into a node that expects shared_ptr. + * + * By registering the inheritance relationship with + * factory.registerPolymorphicCast(), the library + * handles the upcast automatically — both at tree-creation time + * (port type validation) and at runtime (getInput / get). + * + * Transitive casts are supported: if you register A->B and B->C, + * then A->C works automatically. + */ + +//-------------------------------------------------------------- +// A simple class hierarchy +//-------------------------------------------------------------- + +class Animal +{ +public: + using Ptr = std::shared_ptr; + virtual ~Animal() = default; + + virtual std::string name() const + { + return "Animal"; + } +}; + +class Cat : public Animal +{ +public: + using Ptr = std::shared_ptr; + + std::string name() const override + { + return "Cat"; + } +}; + +class Sphynx : public Cat +{ +public: + using Ptr = std::shared_ptr; + + std::string name() const override + { + return "Sphynx"; + } +}; + +//-------------------------------------------------------------- +// Nodes that produce derived types +//-------------------------------------------------------------- + +class CreateCat : public SyncActionNode +{ +public: + CreateCat(const std::string& name, const NodeConfig& config) + : SyncActionNode(name, config) + {} + + NodeStatus tick() override + { + setOutput("animal", std::make_shared()); + return NodeStatus::SUCCESS; + } + + static PortsList providedPorts() + { + return { OutputPort("animal") }; + } +}; + +class CreateSphynx : public SyncActionNode +{ +public: + CreateSphynx(const std::string& name, const NodeConfig& config) + : SyncActionNode(name, config) + {} + + NodeStatus tick() override + { + setOutput("animal", std::make_shared()); + return NodeStatus::SUCCESS; + } + + static PortsList providedPorts() + { + return { OutputPort("animal") }; + } +}; + +//-------------------------------------------------------------- +// A node that consumes the base type +//-------------------------------------------------------------- + +class SayHi : public SyncActionNode +{ +public: + SayHi(const std::string& name, const NodeConfig& config) : SyncActionNode(name, config) + {} + + NodeStatus tick() override + { + auto animal = getInput("animal").value(); + std::cout << "Hi! I am a " << animal->name() << std::endl; + return NodeStatus::SUCCESS; + } + + static PortsList providedPorts() + { + return { InputPort("animal") }; + } +}; + +//-------------------------------------------------------------- + +// clang-format off +static const char* xml_text = R"( + + + + + + + + + + +)"; +// clang-format on + +int main() +{ + BehaviorTreeFactory factory; + + // Register the inheritance relationships. + // This is what makes Cat::Ptr and Sphynx::Ptr assignable to Animal::Ptr ports. + factory.registerPolymorphicCast(); + factory.registerPolymorphicCast(); + + factory.registerNodeType("CreateCat"); + factory.registerNodeType("CreateSphynx"); + factory.registerNodeType("SayHi"); + + auto tree = factory.createTreeFromText(xml_text); + tree.tickWhileRunning(); + + /* Expected output: + * + * Hi! I am a Cat + * Hi! I am a Sphynx + */ + return 0; +} diff --git a/include/behaviortree_cpp/blackboard.h b/include/behaviortree_cpp/blackboard.h index f661a96c0..7229b5921 100644 --- a/include/behaviortree_cpp/blackboard.h +++ b/include/behaviortree_cpp/blackboard.h @@ -4,6 +4,7 @@ #include "behaviortree_cpp/contrib/json.hpp" #include "behaviortree_cpp/exceptions.h" #include "behaviortree_cpp/utils/locked_reference.hpp" +#include "behaviortree_cpp/utils/polymorphic_cast_registry.hpp" #include "behaviortree_cpp/utils/safe_any.hpp" #include @@ -149,6 +150,34 @@ class Blackboard const Blackboard* rootBlackboard() const; + /** + * @brief Set the polymorphic cast registry for this blackboard. + * + * The registry enables polymorphic shared_ptr conversions during get(). + * This is typically set automatically when creating trees via BehaviorTreeFactory. + */ + void setPolymorphicCastRegistry(std::shared_ptr registry) + { + polymorphic_registry_ = std::move(registry); + } + + /** + * @brief Get the polymorphic cast registry (may be null). + */ + [[nodiscard]] const PolymorphicCastRegistry* polymorphicCastRegistry() const + { + return polymorphic_registry_.get(); + } + + /** + * @brief Cast Any value with polymorphic fallback for shared_ptr types. + * + * First attempts a direct cast. If that fails and T is a shared_ptr type, + * tries a polymorphic cast via the registry. Returns Expected with error on failure. + */ + template + [[nodiscard]] Expected tryCastWithPolymorphicFallback(const Any* any) const; + private: mutable std::mutex storage_mutex_; mutable std::recursive_mutex entry_mutex_; @@ -159,6 +188,9 @@ class Blackboard std::shared_ptr createEntryImpl(const std::string& key, const TypeInfo& info); bool autoremapping_ = false; + + // Optional registry for polymorphic shared_ptr conversions + std::shared_ptr polymorphic_registry_; }; /** @@ -177,6 +209,32 @@ void ImportBlackboardFromJSON(const nlohmann::json& json, Blackboard& blackboard //------------------------------------------------------ +template +inline Expected Blackboard::tryCastWithPolymorphicFallback(const Any* any) const +{ + // Try direct cast first + auto result = any->tryCast(); + if(result) + { + return result.value(); + } + + // For shared_ptr types, try polymorphic cast via registry (Issue #943) + if constexpr(is_shared_ptr::value) + { + if(polymorphic_registry_) + { + auto poly_result = any->tryCastWithRegistry(*polymorphic_registry_); + if(poly_result) + { + return poly_result.value(); + } + } + } + + return nonstd::make_unexpected(result.error()); +} + template inline T Blackboard::get(const std::string& key) const { @@ -188,7 +246,12 @@ inline T Blackboard::get(const std::string& key) const throw RuntimeError("Blackboard::get() error. Entry [", key, "] hasn't been initialized, yet"); } - return any_ref.get()->cast(); + auto result = tryCastWithPolymorphicFallback(any); + if(!result) + { + throw std::runtime_error(result.error()); + } + return result.value(); } throw RuntimeError("Blackboard::get() error. Missing key [", key, "]"); } @@ -325,11 +388,17 @@ inline bool Blackboard::get(const std::string& key, T& value) const { if(auto any_ref = getAnyLocked(key)) { - if(any_ref.get()->empty()) + const auto& any = any_ref.get(); + if(any->empty()) { return false; } - value = any_ref.get()->cast(); + auto result = tryCastWithPolymorphicFallback(any); + if(!result) + { + throw std::runtime_error(result.error()); + } + value = result.value(); return true; } return false; @@ -346,8 +415,13 @@ inline Expected Blackboard::getStamped(const std::string& key, T& val return nonstd::make_unexpected(StrCat("Blackboard::getStamped() error. Entry [", key, "] hasn't been initialized, yet")); } - value = entry->value.cast(); - return Timestamp{ entry->sequence_id, entry->stamp }; + auto result = tryCastWithPolymorphicFallback(&entry->value); + if(result) + { + value = result.value(); + return Timestamp{ entry->sequence_id, entry->stamp }; + } + return nonstd::make_unexpected(result.error()); } return nonstd::make_unexpected( StrCat("Blackboard::getStamped() error. Missing key [", key, "]")); diff --git a/include/behaviortree_cpp/bt_factory.h b/include/behaviortree_cpp/bt_factory.h index 7f28fdc1d..242fa17fc 100644 --- a/include/behaviortree_cpp/bt_factory.h +++ b/include/behaviortree_cpp/bt_factory.h @@ -17,6 +17,7 @@ #include "behaviortree_cpp/behavior_tree.h" #include "behaviortree_cpp/contrib/json.hpp" #include "behaviortree_cpp/contrib/magic_enum.hpp" +#include "behaviortree_cpp/utils/polymorphic_cast_registry.hpp" #include #include @@ -529,6 +530,45 @@ class BehaviorTreeFactory [[nodiscard]] const std::unordered_map& substitutionRules() const; + /** + * @brief Register a polymorphic cast relationship between Derived and Base types. + * + * This enables passing shared_ptr to ports expecting shared_ptr + * without type mismatch errors. The relationship is automatically applied + * to all trees created from this factory. + * + * Example: + * factory.registerPolymorphicCast(); + * factory.registerPolymorphicCast(); + * + * @tparam Derived The derived class (must inherit from Base) + * @tparam Base The base class (must be polymorphic) + */ + template + void registerPolymorphicCast() + { + polymorphicCastRegistry().registerCast(); + } + + /** + * @brief Access the polymorphic cast registry. + * + * The registry is shared with all trees created from this factory, + * allowing trees to outlive the factory while maintaining access + * to polymorphic cast relationships. + */ + [[nodiscard]] PolymorphicCastRegistry& polymorphicCastRegistry(); + [[nodiscard]] const PolymorphicCastRegistry& polymorphicCastRegistry() const; + + /** + * @brief Get a shared pointer to the polymorphic cast registry. + * + * This allows trees and blackboards to hold a reference to the registry + * that outlives the factory. + */ + [[nodiscard]] std::shared_ptr + polymorphicCastRegistryPtr() const; + private: struct PImpl; std::unique_ptr _p; diff --git a/include/behaviortree_cpp/tree_node.h b/include/behaviortree_cpp/tree_node.h index e68363247..9fe48568d 100644 --- a/include/behaviortree_cpp/tree_node.h +++ b/include/behaviortree_cpp/tree_node.h @@ -591,7 +591,13 @@ inline Expected TreeNode::getInputStamped(const std::string& key, } else { - destination = any_value.cast(); + auto result = + config().blackboard->tryCastWithPolymorphicFallback(&any_value); + if(!result) + { + throw std::runtime_error(result.error()); + } + destination = result.value(); } return Timestamp{ entry->sequence_id, entry->stamp }; } diff --git a/include/behaviortree_cpp/utils/polymorphic_cast_registry.hpp b/include/behaviortree_cpp/utils/polymorphic_cast_registry.hpp new file mode 100644 index 000000000..6b8880273 --- /dev/null +++ b/include/behaviortree_cpp/utils/polymorphic_cast_registry.hpp @@ -0,0 +1,392 @@ +/* Copyright (C) 2022-2025 Davide Faconti - All Rights Reserved +* +* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +*/ + +#pragma once + +#include "behaviortree_cpp/contrib/any.hpp" +#include "behaviortree_cpp/contrib/expected.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace BT +{ + +/** + * @brief Registry for polymorphic shared_ptr cast relationships. + * + * This enables passing shared_ptr to ports expecting shared_ptr + * without breaking ABI compatibility. Users register inheritance relationships + * at runtime, and the registry handles upcasting/downcasting transparently. + * + * This class is typically owned by BehaviorTreeFactory and passed to Blackboard + * during tree creation. This avoids global state and makes testing easier. + * + * Usage with BehaviorTreeFactory: + * BehaviorTreeFactory factory; + * factory.registerPolymorphicCast(); + * factory.registerPolymorphicCast(); + * auto tree = factory.createTreeFromText(xml); + */ +class PolymorphicCastRegistry +{ +public: + using CastFunction = std::function; + + PolymorphicCastRegistry() = default; + ~PolymorphicCastRegistry() = default; + + // Non-copyable, non-movable (contains mutex) + PolymorphicCastRegistry(const PolymorphicCastRegistry&) = delete; + PolymorphicCastRegistry& operator=(const PolymorphicCastRegistry&) = delete; + PolymorphicCastRegistry(PolymorphicCastRegistry&&) = delete; + PolymorphicCastRegistry& operator=(PolymorphicCastRegistry&&) = delete; + + /** + * @brief Register a Derived -> Base inheritance relationship. + * + * This enables: + * - Upcasting: shared_ptr can be retrieved as shared_ptr + * - Downcasting: shared_ptr can be retrieved as shared_ptr + * (via dynamic_pointer_cast, may return nullptr if types don't match) + * + * @tparam Derived The derived class (must inherit from Base) + * @tparam Base The base class (must be polymorphic - have virtual functions) + */ + template + void registerCast() + { + static_assert(std::is_base_of_v, "Derived must inherit from Base"); + static_assert(std::is_polymorphic_v, "Base must be polymorphic (have virtual " + "functions)"); + + std::unique_lock lock(mutex_); + + // Register upcast: Derived -> Base + auto upcast_key = std::make_pair(std::type_index(typeid(std::shared_ptr)), + std::type_index(typeid(std::shared_ptr))); + + upcasts_[upcast_key] = [](const linb::any& from) -> linb::any { + auto ptr = linb::any_cast>(from); + return std::static_pointer_cast(ptr); + }; + + // Register downcast: Base -> Derived (uses dynamic_pointer_cast) + auto downcast_key = std::make_pair(std::type_index(typeid(std::shared_ptr)), + std::type_index(typeid(std::shared_ptr))); + + downcasts_[downcast_key] = [](const linb::any& from) -> linb::any { + auto ptr = linb::any_cast>(from); + auto derived = std::dynamic_pointer_cast(ptr); + if(!derived) + { + throw std::bad_cast(); + } + return derived; + }; + + // Track inheritance relationship for port compatibility checks + base_types_[std::type_index(typeid(std::shared_ptr))].insert( + std::type_index(typeid(std::shared_ptr))); + } + + /** + * @brief Check if from_type can be converted to to_type. + * + * Returns true if: + * - from_type == to_type + * - from_type is a registered derived type of to_type (upcast) + * - to_type is a registered derived type of from_type (downcast) + */ + [[nodiscard]] bool isConvertible(std::type_index from_type, + std::type_index to_type) const + { + if(from_type == to_type) + { + return true; + } + + std::shared_lock lock(mutex_); + + // Check direct upcast + auto upcast_key = std::make_pair(from_type, to_type); + if(upcasts_.find(upcast_key) != upcasts_.end()) + { + return true; + } + + // Check transitive upcast (e.g., Sphynx -> Cat -> Animal) + if(canUpcastTransitive(from_type, to_type)) + { + return true; + } + + // Check downcast + auto downcast_key = std::make_pair(from_type, to_type); + if(downcasts_.find(downcast_key) != downcasts_.end()) + { + return true; + } + + return false; + } + + /** + * @brief Check if from_type can be UPCAST to to_type (not downcast). + * + * This is stricter than isConvertible - only allows going from + * derived to base, not the reverse. + */ + [[nodiscard]] bool canUpcast(std::type_index from_type, std::type_index to_type) const + { + if(from_type == to_type) + { + return true; + } + + std::shared_lock lock(mutex_); + return canUpcastTransitive(from_type, to_type); + } + + /** + * @brief Attempt to cast the value to the target type. + * + * @param from The source any containing a shared_ptr + * @param from_type The type_index of the stored type + * @param to_type The target type_index + * @return The casted any on success, or an error string on failure + */ + [[nodiscard]] nonstd::expected + tryCast(const linb::any& from, std::type_index from_type, std::type_index to_type) const + { + if(from_type == to_type) + { + return from; + } + + std::shared_lock lock(mutex_); + + // Try direct upcast + auto upcast_key = std::make_pair(from_type, to_type); + auto upcast_it = upcasts_.find(upcast_key); + if(upcast_it != upcasts_.end()) + { + try + { + return upcast_it->second(from); + } + catch(const std::exception& e) + { + return nonstd::make_unexpected(std::string("Direct upcast failed: ") + e.what()); + } + } + + // Try transitive upcast + auto transitive_up = applyTransitiveCasts(from, from_type, to_type, upcasts_, true); + if(transitive_up) + { + return transitive_up; + } + + // Try direct downcast + auto downcast_key = std::make_pair(from_type, to_type); + auto downcast_it = downcasts_.find(downcast_key); + if(downcast_it != downcasts_.end()) + { + try + { + return downcast_it->second(from); + } + catch(const std::exception& e) + { + return nonstd::make_unexpected(std::string("Downcast failed " + "(dynamic_pointer_cast returned " + "null): ") + + e.what()); + } + } + + // Try transitive downcast + auto transitive_down = + applyTransitiveCasts(from, to_type, from_type, downcasts_, false); + if(transitive_down) + { + return transitive_down; + } + + return nonstd::make_unexpected(std::string("No registered polymorphic conversion " + "available")); + } + + /** + * @brief Get all registered base types for a given type. + */ + [[nodiscard]] std::set getBaseTypes(std::type_index type) const + { + std::shared_lock lock(mutex_); + auto it = base_types_.find(type); + if(it != base_types_.end()) + { + return it->second; + } + return {}; + } + + /** + * @brief Clear all registrations (mainly for testing). + */ + void clear() + { + std::unique_lock lock(mutex_); + upcasts_.clear(); + downcasts_.clear(); + base_types_.clear(); + } + +private: + // Check if we can upcast from_type to to_type through a chain of registered casts + [[nodiscard]] bool canUpcastTransitive(std::type_index from_type, + std::type_index to_type) const + { + // Depth-first search to find a path from from_type to to_type + std::set visited; + std::vector queue; + queue.push_back(from_type); + + while(!queue.empty()) + { + auto current = queue.back(); + queue.pop_back(); + + if(visited.count(current) != 0) + { + continue; + } + visited.insert(current); + + auto it = base_types_.find(current); + if(it == base_types_.end()) + { + continue; + } + + for(const auto& base : it->second) + { + if(base == to_type) + { + return true; + } + queue.push_back(base); + } + } + return false; + } + + // Common helper for transitive upcast and downcast. + // + // Performs depth-first search from dfs_start through base_types_ edges, + // looking for dfs_target. When found, reconstructs the path from dfs_target + // back to dfs_start. If reverse_path is true, reverses it so casts are applied + // in [dfs_start -> dfs_target] order; otherwise applies in traced order. + // + // For upcast: dfs_start=from_type, dfs_target=to_type, reverse=true, map=upcasts_ + // For downcast: dfs_start=to_type, dfs_target=from_type, reverse=false, map=downcasts_ + using CastMap = std::map, CastFunction>; + + [[nodiscard]] nonstd::expected applyTransitiveCasts( + const linb::any& from, std::type_index dfs_start, std::type_index dfs_target, + const CastMap& cast_map, bool reverse_path) const + { + // Note: std::type_index has no default constructor, so we can't use operator[] + std::map parent; + std::vector stack; + stack.push_back(dfs_start); + parent.insert({ dfs_start, dfs_start }); + + while(!stack.empty()) + { + auto current = stack.back(); + stack.pop_back(); + + auto it = base_types_.find(current); + if(it == base_types_.end()) + { + continue; + } + + for(const auto& base : it->second) + { + if(parent.find(base) != parent.end()) + { + continue; + } + parent.insert({ base, current }); + if(base == dfs_target) + { + // Reconstruct path: trace from dfs_target back to dfs_start + std::vector path; + auto node = dfs_target; + while(node != dfs_start) + { + path.push_back(node); + node = parent.at(node); + } + path.push_back(dfs_start); + + if(reverse_path) + { + std::reverse(path.begin(), path.end()); + } + + // Apply casts along the path + linb::any current_value = from; + for(size_t i = 0; i + 1 < path.size(); ++i) + { + auto cast_key = std::make_pair(path[i], path[i + 1]); + auto cast_it = cast_map.find(cast_key); + if(cast_it == cast_map.end()) + { + return nonstd::make_unexpected(std::string("Transitive cast: missing step " + "in chain")); + } + try + { + current_value = cast_it->second(current_value); + } + catch(const std::exception& e) + { + return nonstd::make_unexpected(std::string("Transitive cast step " + "failed: ") + + e.what()); + } + } + return current_value; + } + stack.push_back(base); + } + } + return nonstd::make_unexpected(std::string("No transitive path found")); + } + + mutable std::shared_mutex mutex_; + std::map, CastFunction> upcasts_; + std::map, CastFunction> downcasts_; + std::map> base_types_; +}; + +} // namespace BT diff --git a/include/behaviortree_cpp/utils/safe_any.hpp b/include/behaviortree_cpp/utils/safe_any.hpp index e4d5ef98c..3a87caf8f 100644 --- a/include/behaviortree_cpp/utils/safe_any.hpp +++ b/include/behaviortree_cpp/utils/safe_any.hpp @@ -20,8 +20,10 @@ #include "behaviortree_cpp/contrib/expected.hpp" #include "behaviortree_cpp/utils/convert_impl.hpp" #include "behaviortree_cpp/utils/demangle_util.h" +#include "behaviortree_cpp/utils/polymorphic_cast_registry.hpp" #include "behaviortree_cpp/utils/strcat.hpp" +#include #include #include #include @@ -31,6 +33,17 @@ namespace BT static std::type_index UndefinedAnyType = typeid(nullptr); +// Trait to detect std::shared_ptr types (used for polymorphic port support) +template +struct is_shared_ptr : std::false_type +{ +}; + +template +struct is_shared_ptr> : std::true_type +{ +}; + // Rational: since type erased numbers will always use at least 8 bytes // it is faster to cast everything to either double, uint64_t or int64_t. class Any @@ -144,6 +157,12 @@ class Any template nonstd::expected tryCast() const; + // tryCast with polymorphic registry support (Issue #943) + // Attempts polymorphic cast for shared_ptr types using the provided registry. + template + nonstd::expected + tryCastWithRegistry(const PolymorphicCastRegistry& registry) const; + // same as tryCast, but throws if fails template [[nodiscard]] T cast() const @@ -549,4 +568,35 @@ inline nonstd::expected Any::tryCast() const } } +template +inline nonstd::expected +Any::tryCastWithRegistry(const PolymorphicCastRegistry& registry) const +{ + static_assert(is_shared_ptr::value, "tryCastWithRegistry only works with shared_ptr " + "types"); + + if(_any.empty()) + { + return nonstd::make_unexpected("Any::tryCastWithRegistry failed: empty value"); + } + + // Try to cast using the registry + auto result = registry.tryCast(_any, _original_type, typeid(T)); + if(result) + { + try + { + return linb::any_cast(result.value()); + } + catch(const std::exception& e) + { + return nonstd::make_unexpected(StrCat("Polymorphic cast failed: ", e.what())); + } + } + + return nonstd::make_unexpected(StrCat("[Any::tryCastWithRegistry]: ", result.error(), + " (from [", demangle(_original_type), "] to [", + demangle(typeid(T)), "])")); +} + } // end namespace BT diff --git a/src/bt_factory.cpp b/src/bt_factory.cpp index 002aab3a1..432ae6abd 100644 --- a/src/bt_factory.cpp +++ b/src/bt_factory.cpp @@ -108,11 +108,13 @@ struct BehaviorTreeFactory::PImpl std::shared_ptr> scripting_enums; std::shared_ptr parser; std::unordered_map substitution_rules; + std::shared_ptr polymorphic_registry; }; BehaviorTreeFactory::BehaviorTreeFactory() : _p(new PImpl) { _p->parser = std::make_shared(*this); + _p->polymorphic_registry = std::make_shared(); registerNodeType("Fallback"); registerNodeType("AsyncFallback", true); registerNodeType("Sequence"); @@ -427,6 +429,9 @@ Tree BehaviorTreeFactory::createTreeFromText(const std::string& text, const std::string resolved_ID = loadXmlAndResolveTreeId( _p->parser.get(), main_tree_ID, [&] { _p->parser->loadFromText(text); }); + // Set the polymorphic cast registry on the blackboard (Issue #943) + blackboard->setPolymorphicCastRegistry(_p->polymorphic_registry); + Tree tree = resolved_ID.empty() ? _p->parser->instantiateTree(blackboard) : _p->parser->instantiateTree(blackboard, resolved_ID); tree.manifests = this->manifests(); @@ -449,6 +454,9 @@ Tree BehaviorTreeFactory::createTreeFromFile(const std::filesystem::path& file_p const std::string resolved_ID = loadXmlAndResolveTreeId( _p->parser.get(), main_tree_ID, [&] { _p->parser->loadFromFile(file_path); }); + // Set the polymorphic cast registry on the blackboard (Issue #943) + blackboard->setPolymorphicCastRegistry(_p->polymorphic_registry); + Tree tree = resolved_ID.empty() ? _p->parser->instantiateTree(blackboard) : _p->parser->instantiateTree(blackboard, resolved_ID); tree.manifests = this->manifests(); @@ -459,6 +467,9 @@ Tree BehaviorTreeFactory::createTreeFromFile(const std::filesystem::path& file_p Tree BehaviorTreeFactory::createTree(const std::string& tree_name, Blackboard::Ptr blackboard) { + // Set the polymorphic cast registry on the blackboard (Issue #943) + blackboard->setPolymorphicCastRegistry(_p->polymorphic_registry); + auto tree = _p->parser->instantiateTree(blackboard, tree_name); tree.manifests = this->manifests(); tree.remapManifestPointers(); @@ -564,6 +575,22 @@ BehaviorTreeFactory::substitutionRules() const return _p->substitution_rules; } +PolymorphicCastRegistry& BehaviorTreeFactory::polymorphicCastRegistry() +{ + return *_p->polymorphic_registry; +} + +const PolymorphicCastRegistry& BehaviorTreeFactory::polymorphicCastRegistry() const +{ + return *_p->polymorphic_registry; +} + +std::shared_ptr +BehaviorTreeFactory::polymorphicCastRegistryPtr() const +{ + return _p->polymorphic_registry; +} + Tree::Tree() = default; void Tree::remapManifestPointers() diff --git a/src/xml_parsing.cpp b/src/xml_parsing.cpp index 3636b96d5..fbe1d74e5 100644 --- a/src/xml_parsing.cpp +++ b/src/xml_parsing.cpp @@ -41,6 +41,7 @@ #pragma GCC diagnostic pop #endif +#include "behaviortree_cpp/utils/polymorphic_cast_registry.hpp" #include "behaviortree_cpp/xml_parsing.h" #include @@ -928,10 +929,22 @@ TreeNode::Ptr XMLParser::PImpl::createNodeFromXML(const XMLElement* element, if(auto prev_info = blackboard->entryInfo(port_key)) { // Check consistency of types. - bool const port_type_mismatch = + bool port_type_mismatch = (prev_info->isStronglyTyped() && port_info.isStronglyTyped() && prev_info->type() != port_info.type()); + // Allow polymorphic cast for INPUT ports (Issue #943) + // If a registered conversion exists (upcast or downcast), allow the + // connection. Downcasts use dynamic_pointer_cast and may fail at runtime. + if(port_type_mismatch && port_info.direction() == PortDirection::INPUT) + { + if(factory->polymorphicCastRegistry().isConvertible(prev_info->type(), + port_info.type())) + { + port_type_mismatch = false; + } + } + // special case related to convertFromString bool const string_input = (prev_info->type() == typeid(std::string)); @@ -1055,6 +1068,8 @@ void BT::XMLParser::PImpl::recursivelyCreateSubtree( else // special case: SubTreeNode { auto new_bb = Blackboard::create(blackboard); + // Inherit polymorphic cast registry from factory (Issue #943) + new_bb->setPolymorphicCastRegistry(factory->polymorphicCastRegistryPtr()); const std::string subtree_ID = element->Attribute("ID"); std::unordered_map subtree_remapping; bool do_autoremap = false; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 861c461c5..7c5418c1d 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -52,6 +52,7 @@ set(BT_TESTS gtest_while_do_else.cpp gtest_interface.cpp gtest_simple_string.cpp + gtest_polymorphic_ports.cpp gtest_plugin_issue953.cpp script_parser_test.cpp diff --git a/tests/gtest_polymorphic_ports.cpp b/tests/gtest_polymorphic_ports.cpp new file mode 100644 index 000000000..3faf0e6a5 --- /dev/null +++ b/tests/gtest_polymorphic_ports.cpp @@ -0,0 +1,496 @@ +#include "behaviortree_cpp/bt_factory.h" + +#include + +#include "include/animal_hierarchy_test.h" + +using namespace BT; + +//------------------------------------------------------------------- +// Any-level polymorphic cast tests (registry) +//------------------------------------------------------------------- + +TEST(PolymorphicPortTest, AnyCast_SameType) +{ + PolymorphicCastRegistry registry; + registry.registerCast(); + registry.registerCast(); + registry.registerCast(); + + auto animal = std::make_shared(); + Any any_animal(animal); + EXPECT_NO_THROW(auto res = any_animal.cast()); + // Downcast should fail (returns error, doesn't throw) + EXPECT_FALSE(any_animal.tryCastWithRegistry(registry).has_value()); + EXPECT_FALSE(any_animal.tryCastWithRegistry(registry).has_value()); +} + +TEST(PolymorphicPortTest, AnyCast_Upcast) +{ + PolymorphicCastRegistry registry; + registry.registerCast(); + registry.registerCast(); + registry.registerCast(); + + auto cat = std::make_shared(); + Any any_cat(cat); + // Same type works + EXPECT_NO_THROW(auto res = any_cat.cast()); + // Upcast via registry + auto result = any_cat.tryCastWithRegistry(registry); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value()->name(), "Cat"); + // Downcast should fail + EXPECT_FALSE(any_cat.tryCastWithRegistry(registry).has_value()); +} + +TEST(PolymorphicPortTest, AnyCast_TransitiveUpcast) +{ + PolymorphicCastRegistry registry; + registry.registerCast(); + registry.registerCast(); + registry.registerCast(); + + auto sphynx = std::make_shared(); + Any any_sphynx(sphynx); + // Same type works + EXPECT_NO_THROW(auto res = any_sphynx.cast()); + // Upcast to Cat + auto as_cat = any_sphynx.tryCastWithRegistry(registry); + EXPECT_TRUE(as_cat.has_value()); + EXPECT_EQ(as_cat.value()->name(), "Sphynx"); + // Transitive upcast to Animal + auto as_animal = any_sphynx.tryCastWithRegistry(registry); + EXPECT_TRUE(as_animal.has_value()); + EXPECT_EQ(as_animal.value()->name(), "Sphynx"); +} + +TEST(PolymorphicPortTest, AnyCast_DowncastWithRuntimeTypeCheck) +{ + PolymorphicCastRegistry registry; + registry.registerCast(); + registry.registerCast(); + + Cat::Ptr cat = std::make_shared(); // Store Sphynx as Cat + Any any_cat(cat); + // Same type works + EXPECT_NO_THROW(auto res = any_cat.cast()); + // Downcast should work because runtime type is Sphynx + auto as_sphynx = any_cat.tryCastWithRegistry(registry); + EXPECT_TRUE(as_sphynx.has_value()); + EXPECT_EQ(as_sphynx.value()->name(), "Sphynx"); +} + +TEST(PolymorphicPortTest, AnyCast_UnrelatedTypes) +{ + PolymorphicCastRegistry registry; + registry.registerCast(); + registry.registerCast(); + + auto cat = std::make_shared(); + Any any_cat(cat); + EXPECT_FALSE(any_cat.tryCastWithRegistry(registry).has_value()); + + auto dog = std::make_shared(); + Any any_dog(dog); + EXPECT_FALSE(any_dog.tryCastWithRegistry(registry).has_value()); +} + +//------------------------------------------------------------------- +// Test nodes for XML-level polymorphic port testing +//------------------------------------------------------------------- + +class CreateAnimal : public SyncActionNode +{ +public: + CreateAnimal(const std::string& name, const NodeConfig& config) + : SyncActionNode(name, config) + {} + + NodeStatus tick() override + { + setOutput("out_animal", std::make_shared()); + return NodeStatus::SUCCESS; + } + + static PortsList providedPorts() + { + return { OutputPort("out_animal") }; + } +}; + +class CreateCat : public SyncActionNode +{ +public: + CreateCat(const std::string& name, const NodeConfig& config) + : SyncActionNode(name, config) + {} + + NodeStatus tick() override + { + setOutput("out_cat", std::make_shared()); + return NodeStatus::SUCCESS; + } + + static PortsList providedPorts() + { + return { OutputPort("out_cat") }; + } +}; + +class CreateSphynx : public SyncActionNode +{ +public: + CreateSphynx(const std::string& name, const NodeConfig& config) + : SyncActionNode(name, config) + {} + + NodeStatus tick() override + { + setOutput("out_sphynx", std::make_shared()); + return NodeStatus::SUCCESS; + } + + static PortsList providedPorts() + { + return { OutputPort("out_sphynx") }; + } +}; + +class CreateDog : public SyncActionNode +{ +public: + CreateDog(const std::string& name, const NodeConfig& config) + : SyncActionNode(name, config) + {} + + NodeStatus tick() override + { + setOutput("out_dog", std::make_shared()); + return NodeStatus::SUCCESS; + } + + static PortsList providedPorts() + { + return { OutputPort("out_dog") }; + } +}; + +class CreateCatAsAnimal : public SyncActionNode +{ +public: + CreateCatAsAnimal(const std::string& name, const NodeConfig& config) + : SyncActionNode(name, config) + {} + + NodeStatus tick() override + { + setOutput("out_animal", Animal::Ptr(std::make_shared())); + return NodeStatus::SUCCESS; + } + + static PortsList providedPorts() + { + return { OutputPort("out_animal") }; + } +}; + +class PrintAnimalName : public SyncActionNode +{ +public: + PrintAnimalName(const std::string& name, const NodeConfig& config) + : SyncActionNode(name, config) + {} + + NodeStatus tick() override + { + Animal::Ptr animal; + if(!getInput("in_animal", animal) || !animal) + { + return NodeStatus::FAILURE; + } + last_name_ = animal->name(); + return NodeStatus::SUCCESS; + } + + static PortsList providedPorts() + { + return { InputPort("in_animal") }; + } + + static std::string last_name_; +}; +std::string PrintAnimalName::last_name_; + +class PrintCatName : public SyncActionNode +{ +public: + PrintCatName(const std::string& name, const NodeConfig& config) + : SyncActionNode(name, config) + {} + + NodeStatus tick() override + { + Cat::Ptr cat; + if(!getInput("in_cat", cat) || !cat) + { + return NodeStatus::FAILURE; + } + last_name_ = cat->name(); + return NodeStatus::SUCCESS; + } + + static PortsList providedPorts() + { + return { InputPort("in_cat") }; + } + + static std::string last_name_; +}; +std::string PrintCatName::last_name_; + +class PrintDogName : public SyncActionNode +{ +public: + PrintDogName(const std::string& name, const NodeConfig& config) + : SyncActionNode(name, config) + {} + + NodeStatus tick() override + { + Dog::Ptr dog; + if(!getInput("in_dog", dog) || !dog) + { + return NodeStatus::FAILURE; + } + last_name_ = dog->name(); + return NodeStatus::SUCCESS; + } + + static PortsList providedPorts() + { + return { InputPort("in_dog") }; + } + + static std::string last_name_; +}; +std::string PrintDogName::last_name_; + +//------------------------------------------------------------------- +// Blackboard-level polymorphic get/set tests +//------------------------------------------------------------------- + +namespace +{ +Blackboard::Ptr createBlackboardWithRegistry() +{ + auto bb = Blackboard::create(); + auto registry = std::make_shared(); + registry->registerCast(); + registry->registerCast(); + registry->registerCast(); + bb->setPolymorphicCastRegistry(registry); + return bb; +} +} // namespace + +TEST(PolymorphicPortTest, Blackboard_UpcastAndDowncast) +{ + auto bb = createBlackboardWithRegistry(); + + // Store a Cat, retrieve as Animal (upcast) + auto cat = std::make_shared(); + bb->set("pet", cat); + + Animal::Ptr animal; + ASSERT_TRUE(bb->get("pet", animal)); + ASSERT_EQ(animal->name(), "Cat"); + + // Can still get as Cat + Cat::Ptr retrieved_cat; + ASSERT_TRUE(bb->get("pet", retrieved_cat)); + ASSERT_EQ(retrieved_cat->name(), "Cat"); + + // Cannot get as Sphynx (invalid downcast) + Sphynx::Ptr sphynx; + ASSERT_ANY_THROW((void)bb->get("pet", sphynx)); +} + +TEST(PolymorphicPortTest, Blackboard_TransitiveUpcast) +{ + auto bb = createBlackboardWithRegistry(); + + auto sphynx = std::make_shared(); + bb->set("pet", sphynx); + + // Can get as Animal (transitive upcast through Cat) + Animal::Ptr animal; + ASSERT_TRUE(bb->get("pet", animal)); + ASSERT_EQ(animal->name(), "Sphynx"); + + // Can get as Cat (direct upcast) + Cat::Ptr cat; + ASSERT_TRUE(bb->get("pet", cat)); + ASSERT_EQ(cat->name(), "Sphynx"); + + // Can get as Sphynx (same type) + Sphynx::Ptr retrieved_sphynx; + ASSERT_TRUE(bb->get("pet", retrieved_sphynx)); + ASSERT_EQ(retrieved_sphynx->name(), "Sphynx"); +} + +//------------------------------------------------------------------- +// XML tree-level polymorphic port tests +//------------------------------------------------------------------- + +TEST(PolymorphicPortTest, XML_ValidUpcast) +{ + std::string xml_txt = R"( + + + + + + + + + )"; + + BehaviorTreeFactory factory; + RegisterAnimalHierarchy(factory); + factory.registerNodeType("CreateCat"); + factory.registerNodeType("PrintCatName"); + factory.registerNodeType("PrintAnimalName"); + + auto tree = factory.createTreeFromText(xml_txt); + NodeStatus status = tree.tickWhileRunning(); + + ASSERT_EQ(status, NodeStatus::SUCCESS); + ASSERT_EQ(PrintCatName::last_name_, "Cat"); + ASSERT_EQ(PrintAnimalName::last_name_, "Cat"); +} + +TEST(PolymorphicPortTest, XML_TransitiveUpcast) +{ + std::string xml_txt = R"( + + + + + + + + )"; + + BehaviorTreeFactory factory; + RegisterAnimalHierarchy(factory); + factory.registerNodeType("CreateSphynx"); + factory.registerNodeType("PrintAnimalName"); + + auto tree = factory.createTreeFromText(xml_txt); + NodeStatus status = tree.tickWhileRunning(); + + ASSERT_EQ(status, NodeStatus::SUCCESS); + ASSERT_EQ(PrintAnimalName::last_name_, "Sphynx"); +} + +TEST(PolymorphicPortTest, XML_InoutRejectsTypeMismatch) +{ + class UpdateAnimal : public SyncActionNode + { + public: + UpdateAnimal(const std::string& name, const NodeConfig& config) + : SyncActionNode(name, config) + {} + NodeStatus tick() override + { + return NodeStatus::SUCCESS; + } + static PortsList providedPorts() + { + return { BidirectionalPort("animal") }; + } + }; + + std::string xml_txt = R"( + + + + + + + + )"; + + BehaviorTreeFactory factory; + RegisterAnimalHierarchy(factory); + factory.registerNodeType("CreateCat"); + factory.registerNodeType("UpdateAnimal"); + + ASSERT_ANY_THROW((void)factory.createTreeFromText(xml_txt)); +} + +TEST(PolymorphicPortTest, XML_InvalidConnection_UnrelatedTypes) +{ + std::string xml_txt = R"( + + + + + + + + )"; + + BehaviorTreeFactory factory; + factory.registerNodeType("CreateCat"); + factory.registerNodeType("PrintDogName"); + + ASSERT_ANY_THROW((void)factory.createTreeFromText(xml_txt)); +} + +TEST(PolymorphicPortTest, XML_DowncastSucceedsAtRuntime) +{ + std::string xml_txt = R"( + + + + + + + + )"; + + BehaviorTreeFactory factory; + RegisterAnimalHierarchy(factory); + factory.registerNodeType("CreateCatAsAnimal"); + factory.registerNodeType("PrintCatName"); + + auto tree = factory.createTreeFromText(xml_txt); + NodeStatus status = tree.tickWhileRunning(); + + ASSERT_EQ(status, NodeStatus::SUCCESS); + ASSERT_EQ(PrintCatName::last_name_, "Cat"); +} + +TEST(PolymorphicPortTest, XML_DowncastFailsAtRuntime) +{ + std::string xml_txt = R"( + + + + + + + + )"; + + BehaviorTreeFactory factory; + RegisterAnimalHierarchy(factory); + factory.registerNodeType("CreateAnimal"); + factory.registerNodeType("PrintCatName"); + + auto tree = factory.createTreeFromText(xml_txt); + // Runtime should fail (actual type is Animal, not Cat) + ASSERT_EQ(tree.tickWhileRunning(), NodeStatus::FAILURE); +} diff --git a/tests/include/animal_hierarchy_test.h b/tests/include/animal_hierarchy_test.h new file mode 100644 index 000000000..0454c6bf2 --- /dev/null +++ b/tests/include/animal_hierarchy_test.h @@ -0,0 +1,79 @@ +#pragma once + +#include +#include + +// Animal hierarchy for testing polymorphic port connections +// +// Animal +// / \ +// Cat Dog +// | +// Sphynx + +class Animal +{ +public: + using Ptr = std::shared_ptr; + + Animal() = default; + virtual ~Animal() = default; + Animal(const Animal&) = default; + Animal& operator=(const Animal&) = default; + Animal(Animal&&) = default; + Animal& operator=(Animal&&) = default; + + virtual std::string name() const + { + return "Animal"; + } +}; + +class Cat : public Animal +{ +public: + using Ptr = std::shared_ptr; + + std::string name() const override + { + return "Cat"; + } + + void meow() const + {} +}; + +class Dog : public Animal +{ +public: + using Ptr = std::shared_ptr; + + std::string name() const override + { + return "Dog"; + } + + void bark() const + {} +}; + +class Sphynx : public Cat +{ +public: + using Ptr = std::shared_ptr; + + std::string name() const override + { + return "Sphynx"; + } +}; + +// Helper to register the animal hierarchy with a factory +// Usage: RegisterAnimalHierarchy(factory); +#include "behaviortree_cpp/bt_factory.h" +inline void RegisterAnimalHierarchy(BT::BehaviorTreeFactory& factory) +{ + factory.registerPolymorphicCast(); + factory.registerPolymorphicCast(); + factory.registerPolymorphicCast(); +}