diff --git a/src/MongoDB.Bson/Serialization/IBsonSerializerExtensions.cs b/src/MongoDB.Bson/Serialization/IBsonSerializerExtensions.cs index fd5998c93b3..9d7097aaf94 100644 --- a/src/MongoDB.Bson/Serialization/IBsonSerializerExtensions.cs +++ b/src/MongoDB.Bson/Serialization/IBsonSerializerExtensions.cs @@ -50,6 +50,48 @@ public static TValue Deserialize(this IBsonSerializer serializer return serializer.Deserialize(context, args); } + /// + /// Gets the serializer for a base type starting from a serializer for a derived type. + /// + /// The serializer for the derived type. + /// The base type. + /// The serializer for the base type. + public static IBsonSerializer GetBaseTypeSerializer(this IBsonSerializer derivedTypeSerializer, Type baseType) + { + if (derivedTypeSerializer.ValueType == baseType) + { + return derivedTypeSerializer; + } + + if (!baseType.IsAssignableFrom(derivedTypeSerializer.ValueType)) + { + throw new ArgumentException($"{baseType} is not assignable from {derivedTypeSerializer.ValueType}."); + } + + return BsonSerializer.LookupSerializer(baseType); // TODO: should be able to navigate from serializer + } + + /// + /// Gets the serializer for a derived type starting from a serializer for a base type. + /// + /// The serializer for the base type. + /// The derived type. + /// The serializer for the derived type. + public static IBsonSerializer GetDerivedTypeSerializer(this IBsonSerializer baseTypeSerializer, Type derivedType) + { + if (baseTypeSerializer.ValueType == derivedType) + { + return baseTypeSerializer; + } + + if (!baseTypeSerializer.ValueType.IsAssignableFrom(derivedType)) + { + throw new ArgumentException($"{baseTypeSerializer.ValueType} is not assignable from {derivedType}."); + } + + return BsonSerializer.LookupSerializer(derivedType); // TODO: should be able to navigate from serializer + } + /// /// Gets the discriminator convention for a serializer. /// diff --git a/src/MongoDB.Bson/Serialization/Serializers/ArraySerializer.cs b/src/MongoDB.Bson/Serialization/Serializers/ArraySerializer.cs index f10cb541d16..e90210fbc14 100644 --- a/src/MongoDB.Bson/Serialization/Serializers/ArraySerializer.cs +++ b/src/MongoDB.Bson/Serialization/Serializers/ArraySerializer.cs @@ -13,10 +13,29 @@ * limitations under the License. */ +using System; using System.Collections.Generic; namespace MongoDB.Bson.Serialization.Serializers { + /// + /// A static factory class for ArraySerializers. + /// + public static class ArraySerializer + { + /// + /// Creates an ArraySerializer. + /// + /// The item serializer. + /// An ArraySerializer. + public static IBsonSerializer Create(IBsonSerializer itemSerializer) + { + var itemType = itemSerializer.ValueType; + var arraySerializerType = typeof(ArraySerializer<>).MakeGenericType(itemType); + return (IBsonSerializer)Activator.CreateInstance(arraySerializerType, itemSerializer); + } + } + /// /// Represents a serializer for one-dimensional arrays. /// diff --git a/src/MongoDB.Bson/Serialization/Serializers/CharSerializer.cs b/src/MongoDB.Bson/Serialization/Serializers/CharSerializer.cs index a1688526364..787720f34cf 100644 --- a/src/MongoDB.Bson/Serialization/Serializers/CharSerializer.cs +++ b/src/MongoDB.Bson/Serialization/Serializers/CharSerializer.cs @@ -22,6 +22,15 @@ namespace MongoDB.Bson.Serialization.Serializers /// public sealed class CharSerializer : StructSerializerBase, IRepresentationConfigurable { + #region static + private static readonly CharSerializer __instance = new(); + + /// + /// Returns the default instance of CharSerializer. + /// + public static CharSerializer Instance => __instance; + #endregion + // private fields private readonly BsonType _representation; diff --git a/src/MongoDB.Bson/Serialization/Serializers/DictionarySerializerBase.cs b/src/MongoDB.Bson/Serialization/Serializers/DictionarySerializerBase.cs index 3f6ff642bdd..2535cbed080 100644 --- a/src/MongoDB.Bson/Serialization/Serializers/DictionarySerializerBase.cs +++ b/src/MongoDB.Bson/Serialization/Serializers/DictionarySerializerBase.cs @@ -499,20 +499,15 @@ obj is DictionarySerializerBase other && /// public bool TryGetItemSerializationInfo(out BsonSerializationInfo serializationInfo) { - if (_dictionaryRepresentation is DictionaryRepresentation.ArrayOfArrays or DictionaryRepresentation.ArrayOfDocuments) - { - var representation = _dictionaryRepresentation == DictionaryRepresentation.ArrayOfArrays - ? BsonType.Array - : BsonType.Document; - var keySerializer = _lazyKeySerializer.Value; - var valueSerializer = _lazyValueSerializer.Value; - var keyValuePairSerializer = new KeyValuePairSerializer(representation, keySerializer, valueSerializer); - serializationInfo = new BsonSerializationInfo(null, keyValuePairSerializer, keyValuePairSerializer.ValueType); - return true; - } - - serializationInfo = null; - return false; + var representation = _dictionaryRepresentation == DictionaryRepresentation.ArrayOfArrays + ? BsonType.Array + : BsonType.Document; + var keySerializer = _lazyKeySerializer.Value; + var valueSerializer = _lazyValueSerializer.Value; + var keyValuePairSerializer = new KeyValuePairSerializer(representation, keySerializer, valueSerializer); + + serializationInfo = new BsonSerializationInfo(null, keyValuePairSerializer, keyValuePairSerializer.ValueType); + return true; } /// diff --git a/src/MongoDB.Bson/Serialization/Serializers/NullableSerializer.cs b/src/MongoDB.Bson/Serialization/Serializers/NullableSerializer.cs index 8740bdd3a9b..423b9500bed 100644 --- a/src/MongoDB.Bson/Serialization/Serializers/NullableSerializer.cs +++ b/src/MongoDB.Bson/Serialization/Serializers/NullableSerializer.cs @@ -33,6 +33,73 @@ public interface INullableSerializer /// public static class NullableSerializer { + private readonly static IBsonSerializer __nullableBooleanInstance = new NullableSerializer(BooleanSerializer.Instance); + private readonly static IBsonSerializer __nullableDecimalInstance = new NullableSerializer(DecimalSerializer.Instance); + private readonly static IBsonSerializer __nullableDecimal128Instance = new NullableSerializer(Decimal128Serializer.Instance); + private readonly static IBsonSerializer __nullableDoubleInstance = new NullableSerializer(DoubleSerializer.Instance); + private readonly static IBsonSerializer __nullableInt32Instance = new NullableSerializer(Int32Serializer.Instance); + private readonly static IBsonSerializer __nullableInt64Instance = new NullableSerializer(Int64Serializer.Instance); + private readonly static IBsonSerializer __nullableLocalDateTimeInstance = new NullableSerializer(DateTimeSerializer.LocalInstance); + private readonly static IBsonSerializer __nullableObjectIdInstance = new NullableSerializer(ObjectIdSerializer.Instance); + private readonly static IBsonSerializer __nullableSingleInstance = new NullableSerializer(SingleSerializer.Instance); + private readonly static IBsonSerializer __nullableStandardGuidInstance = new NullableSerializer(GuidSerializer.StandardInstance); + private readonly static IBsonSerializer __nullableUtcDateTimeInstance = new NullableSerializer(DateTimeSerializer.UtcInstance); + + /// + /// Gets a serializer for nullable bools. + /// + public static IBsonSerializer NullableBooleanInstance => __nullableBooleanInstance; + + /// + /// Gets a serializer for nullable decimals. + /// + public static IBsonSerializer NullableDecimalInstance => __nullableDecimalInstance; + + /// + /// Gets a serializer for nullable Decimal128s. + /// + public static IBsonSerializer NullableDecimal128Instance => __nullableDecimal128Instance; + + /// + /// Gets a serializer for nullable doubles. + /// + public static IBsonSerializer NullableDoubleInstance => __nullableDoubleInstance; + + /// + /// Gets a serializer for nullable ints. + /// + public static IBsonSerializer NullableInt32Instance => __nullableInt32Instance; + + /// + /// Gets a serializer for nullable longs. + /// + public static IBsonSerializer NullableInt64Instance => __nullableInt64Instance; + + /// + /// Gets a serializer for local DateTime. + /// + public static IBsonSerializer NullableLocalDateTimeInstance => __nullableLocalDateTimeInstance; + + /// + /// Gets a serializer for nullable floats. + /// + public static IBsonSerializer NullableSingleInstance => __nullableSingleInstance; + + /// + /// Gets a serializer for nullable ObjectIds. + /// + public static IBsonSerializer NullableObjectIdInstance => __nullableObjectIdInstance; + + /// + /// Gets a serializer for nullable Guids with standard representation. + /// + public static IBsonSerializer NullableStandardGuidInstance => __nullableStandardGuidInstance; + + /// + /// Gets a serializer for UTC DateTime. + /// + public static IBsonSerializer NullableUtcDateTimeInstance => __nullableUtcDateTimeInstance; + /// /// Creates a NullableSerializer. /// diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs index db7618ce677..e80842589d8 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs @@ -60,5 +60,17 @@ public static TValue GetConstantValue(this Expression expression, Expres var message = $"Expression must be a constant: {expression} in {containingExpression}."; throw new ExpressionNotSupportedException(message); } + + public static bool IsConvert(this Expression expression, out Expression operand) + { + if (expression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression) + { + operand = unaryExpression.Operand; + return true; + } + + operand = null; + return false; + } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs index 40e41bbd51c..4b4b5906ca3 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs @@ -64,7 +64,8 @@ private AstStage RenderProjectStage( out IBsonSerializer outputSerializer) { var partiallyEvaluatedOutput = (Expression>)LinqExpressionPreprocessor.Preprocess(_output); - var context = TranslationContext.Create(translationOptions); + var parameter = partiallyEvaluatedOutput.Parameters.Single(); + var context = TranslationContext.Create(partiallyEvaluatedOutput, initialNode: parameter, initialSerializer: inputSerializer, translationOptions: translationOptions); var outputTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedOutput, inputSerializer, asRoot: true); var (projectStage, projectSerializer) = ProjectionHelper.CreateProjectStage(outputTranslation); outputSerializer = (IBsonSerializer)projectSerializer; @@ -106,7 +107,8 @@ protected override AstStage RenderGroupingStage( out IBsonSerializer> groupingOutputSerializer) { var partiallyEvaluatedGroupBy = (Expression>)LinqExpressionPreprocessor.Preprocess(_groupBy); - var context = TranslationContext.Create(translationOptions); + var parameter = partiallyEvaluatedGroupBy.Parameters.Single(); + var context = TranslationContext.Create(partiallyEvaluatedGroupBy, initialNode: parameter, initialSerializer: inputSerializer, translationOptions: translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); var valueSerializer = (IBsonSerializer)groupByTranslation.Serializer; @@ -150,7 +152,8 @@ protected override AstStage RenderGroupingStage( out IBsonSerializer, TInput>> groupingOutputSerializer) { var partiallyEvaluatedGroupBy = (Expression>)LinqExpressionPreprocessor.Preprocess(_groupBy); - var context = TranslationContext.Create(translationOptions); + var parameter = partiallyEvaluatedGroupBy.Parameters.Single(); + var context = TranslationContext.Create(partiallyEvaluatedGroupBy, initialNode: parameter, initialSerializer: inputSerializer, translationOptions: translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); var valueSerializer = (IBsonSerializer)groupByTranslation.Serializer; @@ -188,7 +191,8 @@ protected override AstStage RenderGroupingStage( out IBsonSerializer> groupingOutputSerializer) { var partiallyEvaluatedGroupBy = (Expression>)LinqExpressionPreprocessor.Preprocess(_groupBy); - var context = TranslationContext.Create(translationOptions); + var parameter = partiallyEvaluatedGroupBy.Parameters.Single(); + var context = TranslationContext.Create(partiallyEvaluatedGroupBy, initialNode: parameter, initialSerializer: inputSerializer, translationOptions: translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.RootVar); var groupBySerializer = (IBsonSerializer)groupByTranslation.Serializer; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/BsonTypeExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/BsonTypeExtensions.cs new file mode 100644 index 00000000000..de23e6e9d5e --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/BsonTypeExtensions.cs @@ -0,0 +1,24 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using MongoDB.Bson; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Misc; + +internal static class BsonTypeExtensions +{ + public static bool IsNumeric(this BsonType bsonType) + => bsonType is BsonType.Decimal128 or BsonType.Double or BsonType.Int32 or BsonType.Int64; +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ExpressionHelper.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ExpressionHelper.cs index 2b5c4a3a012..a2eed8cafd8 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ExpressionHelper.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ExpressionHelper.cs @@ -35,7 +35,8 @@ public static LambdaExpression UnquoteLambdaIfQueryableMethod(MethodInfo method, Ensure.IsNotNull(method, nameof(method)); Ensure.IsNotNull(expression, nameof(expression)); - if (method.DeclaringType == typeof(Queryable)) + var declaringType = method.DeclaringType; + if (declaringType == typeof(Queryable) || declaringType == typeof(MongoQueryable)) { return UnquoteLambda(expression); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/IBsonSerializerExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/IBsonSerializerExtensions.cs new file mode 100644 index 00000000000..cfed4b06026 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/IBsonSerializerExtensions.cs @@ -0,0 +1,165 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using MongoDB.Bson.Serialization; +using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Misc; + +internal static class IBsonSerializerExtensions +{ + public static bool CanBeAssignedTo(this IBsonSerializer sourceSerializer, IBsonSerializer targetSerializer) + { + if (sourceSerializer.Equals(targetSerializer)) + { + return true; + } + + if (sourceSerializer.ValueType.IsNumeric() && + targetSerializer.ValueType.IsNumeric() && + sourceSerializer.HasNumericRepresentation() && + targetSerializer.HasNumericRepresentation()) + { + return true; + } + + if (targetSerializer.ValueType.IsAssignableFrom(sourceSerializer.ValueType)) + { + return true; + } + + return false; + } + + public static IBsonSerializer GetItemSerializer(this IBsonSerializer serializer) + => ArraySerializerHelper.GetItemSerializer(serializer); + + public static IBsonSerializer GetItemSerializer(this IBsonSerializer serializer, int index) + { + if (serializer is IPolymorphicArraySerializer polymorphicArraySerializer) + { + return polymorphicArraySerializer.GetItemSerializer(index); + } + else + { + return serializer.GetItemSerializer(); + } + } + + public static IBsonSerializer GetItemSerializer(this IBsonSerializer serializer, Expression indexExpression, Expression containingExpression) + { + if (serializer is IPolymorphicArraySerializer polymorphicArraySerializer) + { + var index = indexExpression.GetConstantValue(containingExpression); + return polymorphicArraySerializer.GetItemSerializer(index); + } + else + { + return serializer.GetItemSerializer(); + } + } + + public static IReadOnlyList GetMatchingMemberSerializationInfosForConstructorParameters( + this IBsonSerializer serializer, + Expression expression, + ConstructorInfo constructorInfo) + { + if (serializer is not IBsonDocumentSerializer documentSerializer) + { + throw new ExpressionNotSupportedException(expression, because: $"serializer type {serializer.GetType().Name} does not implement IBsonDocumentSerializer"); + } + + var matchingMemberSerializationInfos = new List(); + foreach (var constructorParameter in constructorInfo.GetParameters()) + { + var matchingMemberSerializationInfo = GetMatchingMemberSerializationInfo(expression, documentSerializer, constructorParameter.Name); + matchingMemberSerializationInfos.Add(matchingMemberSerializationInfo); + } + + return matchingMemberSerializationInfos; + + static BsonSerializationInfo GetMatchingMemberSerializationInfo( + Expression expression, + IBsonDocumentSerializer documentSerializer, + string constructorParameterName) + { + var possibleMatchingMembers = documentSerializer.ValueType.GetMembers().Where(m => m.Name.Equals(constructorParameterName, StringComparison.OrdinalIgnoreCase)).ToArray(); + if (possibleMatchingMembers.Length == 0) + { + throw new ExpressionNotSupportedException(expression, because: $"no matching member found for constructor parameter: {constructorParameterName}"); + } + if (possibleMatchingMembers.Length > 1) + { + throw new ExpressionNotSupportedException(expression, because: $"multiple possible matching members found for constructor parameter: {constructorParameterName}"); + } + var matchingMemberName = possibleMatchingMembers[0].Name; + + if (!documentSerializer.TryGetMemberSerializationInfo(matchingMemberName, out var matchingMemberSerializationInfo)) + { + throw new ExpressionNotSupportedException(expression, because: $"serializer of type {documentSerializer.GetType().Name} did not provide serialization info for member {matchingMemberName}"); + } + + return matchingMemberSerializationInfo; + } + } + + public static bool HasNumericRepresentation(this IBsonSerializer serializer) + { + return + serializer is IHasRepresentationSerializer hasRepresentationSerializer && + hasRepresentationSerializer.Representation.IsNumeric(); + } + + public static bool IsKeyValuePairSerializer( + this IBsonSerializer serializer, + out string keyElementName, + out string valueElementName, + out IBsonSerializer keySerializer, + out IBsonSerializer valueSerializer) + { + // TODO: add properties to IKeyValuePairSerializer to let us extract the needed information + // note: we can only verify the existence of "Key" and "Value" properties, but can't verify there are no others + if (serializer.ValueType is var valueType && + valueType.IsConstructedGenericType && + valueType.GetGenericTypeDefinition() == typeof(KeyValuePair<,>) && + serializer is IBsonDocumentSerializer documentSerializer && + documentSerializer.TryGetMemberSerializationInfo("Key", out var keySerializationInfo) && + documentSerializer.TryGetMemberSerializationInfo("Value", out var valueSerializationInfo)) + { + keyElementName = keySerializationInfo.ElementName; + valueElementName = valueSerializationInfo.ElementName; + keySerializer = keySerializationInfo.Serializer; + valueSerializer = valueSerializationInfo.Serializer; + return true; + } + + keyElementName = null; + valueElementName = null; + keySerializer = null; + valueSerializer = null; + return false; + } + + public static IBsonSerializer Unwrapped(this IBsonSerializer serializer) + { + return serializer is IWrappedValueSerializer wrappedValueSerializer ? wrappedValueSerializer.ValueSerializer.Unwrapped() : serializer; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/MethodInfoExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/MethodInfoExtensions.cs index f73c074c835..f7e390f461e 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/MethodInfoExtensions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/MethodInfoExtensions.cs @@ -15,6 +15,7 @@ using System; using System.Linq; +using System.Collections.Generic; using System.Reflection; namespace MongoDB.Driver.Linq.Linq3Implementation.Misc @@ -110,6 +111,37 @@ public static bool IsInstanceCompareToMethod(this MethodInfo method) return false; } + public static bool IsOneOf(this MethodInfo method, HashSet comparands) + { + if (comparands != null) + { + if (method.IsGenericMethod) + { + var methodDefinition = method.GetGenericMethodDefinition(); + return comparands.Contains(methodDefinition); + } + else + { + return comparands.Contains(method); + } + } + + return false; + } + + public static bool IsOneOf(this MethodInfo method, params HashSet[] comparands) + { + for (var i = 0; i < comparands.Length; i++) + { + if (method.IsOneOf(comparands[i])) + { + return true; + } + } + + return false; + } + public static bool IsOneOf(this MethodInfo method, MethodInfo comparand1, MethodInfo comparand2) { return method.Is(comparand1) || method.Is(comparand2); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/SerializationHelper.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/SerializationHelper.cs index 0b34b0bd7cf..c0c4a473eb2 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/SerializationHelper.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/SerializationHelper.cs @@ -97,7 +97,7 @@ public static BsonType GetRepresentation(IBsonSerializer serializer) return GetRepresentation(downcastingSerializer.DerivedSerializer); } - if (serializer is IEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer) + if (serializer is IAsEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer) { return GetRepresentation(enumUnderlyingTypeSerializer.EnumSerializer); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs index 636c616deb3..c34725e3d17 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs @@ -72,7 +72,7 @@ public static Type GetIEnumerableGenericInterface(this Type enumerableType) throw new InvalidOperationException($"Could not find IEnumerable interface of type: {enumerableType}."); } - public static bool Implements(this Type type, Type @interface) + public static bool ImplementsInterface(this Type type, Type @interface) { if (type == @interface) { @@ -102,6 +102,7 @@ public static bool Implements(this Type type, Type @interface) public static bool ImplementsDictionaryInterface(this Type type, out Type keyType, out Type valueType) { + // note: returns true for IReadOnlyDictionary also if (TryGetGenericInterface(type, __dictionaryInterfaceDefinitions, out var dictionaryInterface)) { var genericArguments = dictionaryInterface.GetGenericArguments(); @@ -146,6 +147,30 @@ public static bool ImplementsIList(this Type type, out Type itemType) return false; } + public static bool ImplementsIOrderedEnumerable(this Type type, out Type itemType) + { + if (TryGetIOrderedEnumerableGenericInterface(type, out var iOrderedEnumerableType)) + { + itemType = iOrderedEnumerableType.GetGenericArguments()[0]; + return true; + } + + itemType = null; + return false; + } + + public static bool ImplementsIOrderedQueryable(this Type type, out Type itemType) + { + if (TryGetIOrderedQueryableGenericInterface(type, out var iorderedQueryableType)) + { + itemType = iorderedQueryableType.GetGenericArguments()[0]; + return true; + } + + itemType = null; + return false; + } + public static bool ImplementsIQueryable(this Type type, out Type itemType) { if (TryGetIQueryableGenericInterface(type, out var iqueryableType)) @@ -158,6 +183,25 @@ public static bool ImplementsIQueryable(this Type type, out Type itemType) return false; } + public static bool ImplementsIQueryableOf(this Type type, Type itemType) + { + return + ImplementsIEnumerable(type, out var actualItemType) && + actualItemType == itemType; + } + + public static bool ImplementsISet(this Type type, out Type itemType) + { + if (TryGetISetGenericInterface(type, out var isetType)) + { + itemType = isetType.GetGenericArguments()[0]; + return true; + } + + itemType = null; + return false; + } + public static bool Is(this Type type, Type comparand) { if (type == comparand) @@ -197,11 +241,14 @@ public static bool IsArray(this Type type, out Type itemType) return false; } + public static bool IsBoolean(this Type type) + { + return type == typeof(bool); + } + public static bool IsBooleanOrNullableBoolean(this Type type) { - return - type == typeof(bool) || - type.IsNullable(out var valueType) && valueType == typeof(bool); + return IsBoolean(type) || type.IsNullable(out var valueType) && IsBoolean(valueType); } public static bool IsConvertibleToEnum(this Type type) @@ -294,23 +341,21 @@ public static bool IsNullableOf(this Type type, Type valueType) return type.IsNullable(out var nullableValueType) && nullableValueType == valueType; } - public static bool IsReadOnlySpanOf(this Type type, Type itemType) - { - return - type.IsGenericType && - type.GetGenericTypeDefinition() == typeof(ReadOnlySpan<>) && - type.GetGenericArguments()[0] == itemType; - } - public static bool IsNumeric(this Type type) { - return - type == typeof(int) || - type == typeof(long) || - type == typeof(double) || - type == typeof(float) || - type == typeof(decimal) || - type == typeof(Decimal128); + return Type.GetTypeCode(type) is + TypeCode.Byte or + TypeCode.Char or // TODO: should we really treat char as numeric? + TypeCode.Decimal or + TypeCode.Double or + TypeCode.Int16 or + TypeCode.Int32 or + TypeCode.Int64 or + TypeCode.SByte or + TypeCode.Single or + TypeCode.UInt16 or + TypeCode.UInt32 or + TypeCode.UInt64; } public static bool IsNumericOrNullableNumeric(this Type type) @@ -320,6 +365,14 @@ public static bool IsNumericOrNullableNumeric(this Type type) type.IsNullable(out var valueType) && valueType.IsNumeric(); } + public static bool IsReadOnlySpanOf(this Type type, Type itemType) + { + return + type.IsGenericType && + type.GetGenericTypeDefinition() == typeof(ReadOnlySpan<>) && + type.GetGenericArguments()[0] == itemType; + } + public static bool IsSameAsOrNullableOf(this Type type, Type valueType) { return type == valueType || type.IsNullableOf(valueType); @@ -337,7 +390,7 @@ public static bool IsSubclassOfOrImplements(this Type type, Type baseTypeOrInter { return type.IsSubclassOf(baseTypeOrInterface) || - type.Implements(baseTypeOrInterface); + type.ImplementsInterface(baseTypeOrInterface); } public static bool IsTuple(this Type type) @@ -386,9 +439,18 @@ public static bool TryGetIEnumerableGenericInterface(this Type type, out Type ie public static bool TryGetIListGenericInterface(this Type type, out Type ilistGenericInterface) => TryGetGenericInterface(type, typeof(IList<>), out ilistGenericInterface); + public static bool TryGetIOrderedEnumerableGenericInterface(this Type type, out Type iorderedEnumerableGenericInterface) + => TryGetGenericInterface(type, typeof(IOrderedEnumerable<>), out iorderedEnumerableGenericInterface); + + public static bool TryGetIOrderedQueryableGenericInterface(this Type type, out Type iorderedQueryableGenericInterface) + => TryGetGenericInterface(type, typeof(IOrderedQueryable<>), out iorderedQueryableGenericInterface); + public static bool TryGetIQueryableGenericInterface(this Type type, out Type iqueryableGenericInterface) => TryGetGenericInterface(type, typeof(IQueryable<>), out iqueryableGenericInterface); + public static bool TryGetISetGenericInterface(this Type type, out Type isetGenericInterface) + => TryGetGenericInterface(type, typeof(ISet<>), out isetGenericInterface); + private static TValue GetDefaultValueGeneric() { return default(TValue); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/MongoQuery.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/MongoQuery.cs index fe96bacae36..a868f3a8508 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/MongoQuery.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/MongoQuery.cs @@ -41,7 +41,7 @@ internal class MongoQuery : MongoQuery, IOrderedQue public MongoQuery(MongoQueryProvider provider) { _provider = provider; - _expression = Expression.Constant(this); + _expression = Expression.Constant(this, typeof(IQueryable<>).MakeGenericType(typeof(TDocument))); } public MongoQuery(MongoQueryProvider provider, Expression expression) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/BsonDocumentMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/BsonDocumentMethod.cs index fd12fccd4c9..1e4e3b582c4 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/BsonDocumentMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/BsonDocumentMethod.cs @@ -22,14 +22,20 @@ internal static class BsonDocumentMethod { // private static fields private static readonly MethodInfo __addWithNameAndValue; + private static readonly MethodInfo __getItemWithIndex; + private static readonly MethodInfo __getItemWithName; // static constructor static BsonDocumentMethod() { __addWithNameAndValue = ReflectionInfo.Method((BsonDocument document, string name, BsonValue value) => document.Add(name, value)); + __getItemWithIndex = ReflectionInfo.Method((BsonDocument document, int index) => document[index]); + __getItemWithName = ReflectionInfo.Method((BsonDocument document, string name) => document[name]); } // public static properties public static MethodInfo AddWithNameAndValue => __addWithNameAndValue; + public static MethodInfo GetItemWithIndex => __getItemWithIndex; + public static MethodInfo GetItemWithName => __getItemWithName; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DateTimeMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DateTimeMethod.cs index 4e677ffe5f3..bcc67a291eb 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DateTimeMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DateTimeMethod.cs @@ -14,6 +14,7 @@ */ using System; +using System.Collections.Generic; using System.Reflection; namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection @@ -61,6 +62,15 @@ internal static class DateTimeMethod private static readonly MethodInfo __week; private static readonly MethodInfo __weekWithTimezone; + // sets of methods + private static readonly HashSet __addOrSubtractOverloads; + private static readonly HashSet __addOrSubtractWithTimeSpanOverloads; + private static readonly HashSet __addOrSubtractWithTimezoneOverloads; + private static readonly HashSet __addOrSubtractWithUnitOverloads; + private static readonly HashSet __subtractReturningDateTimeOverloads; + private static readonly HashSet __subtractReturningInt64Overloads; + private static readonly HashSet __subtractReturningTimeSpanWithMillisecondsUnitsOverloads; + // static constructor static DateTimeMethod() { @@ -103,6 +113,91 @@ static DateTimeMethod() __truncateWithBinSizeAndTimezone = ReflectionInfo.Method((DateTime @this, DateTimeUnit unit, long binSize, string timezone) => @this.Truncate(unit, binSize, timezone)); __week = ReflectionInfo.Method((DateTime @this) => @this.Week()); __weekWithTimezone = ReflectionInfo.Method((DateTime @this, string timezone) => @this.Week(timezone)); + + // initialize sets of methods after individual methods + __addOrSubtractOverloads = + [ + __add, + __addDays, + __addDaysWithTimezone, + __addHours, + __addHoursWithTimezone, + __addMilliseconds, + __addMillisecondsWithTimezone, + __addMinutes, + __addMinutesWithTimezone, + __addMonths, + __addMonthsWithTimezone, + __addQuarters, + __addQuartersWithTimezone, + __addSeconds, + __addSecondsWithTimezone, + __addTicks, + __addWeeks, + __addWeeksWithTimezone, + __addWithTimezone, + __addWithUnit, + __addWithUnitAndTimezone, + __addYears, + __addYearsWithTimezone, + __subtractWithTimeSpan, + __subtractWithTimeSpanAndTimezone, + __subtractWithUnit, + __subtractWithUnitAndTimezone + ]; + + __addOrSubtractWithTimeSpanOverloads = + [ + __add, + __addWithTimezone, + __subtractWithTimeSpan, + __subtractWithTimeSpanAndTimezone + ]; + + __addOrSubtractWithTimezoneOverloads = + [ + __addDaysWithTimezone, + __addHoursWithTimezone, + __addMillisecondsWithTimezone, + __addMinutesWithTimezone, + __addMonthsWithTimezone, + __addQuartersWithTimezone, + __addSecondsWithTimezone, + __addWeeksWithTimezone, + __addWithTimezone, + __addWithUnitAndTimezone, + __addYearsWithTimezone, + __subtractWithTimeSpanAndTimezone, + __subtractWithUnitAndTimezone + ]; + + __addOrSubtractWithUnitOverloads = + [ + __addWithUnit, + __addWithUnitAndTimezone, + __subtractWithUnit, + __subtractWithUnitAndTimezone + ]; + + __subtractReturningDateTimeOverloads = + [ + __subtractWithTimeSpan, + __subtractWithTimeSpanAndTimezone, + __subtractWithUnit, + __subtractWithUnitAndTimezone + ]; + + __subtractReturningInt64Overloads = + [ + __subtractWithDateTimeAndUnit, + __subtractWithDateTimeAndUnitAndTimezone + ]; + + __subtractReturningTimeSpanWithMillisecondsUnitsOverloads = + [ + __subtractWithDateTime, + __subtractWithDateTimeAndTimezone + ]; } // public properties @@ -145,5 +240,14 @@ static DateTimeMethod() public static MethodInfo TruncateWithBinSizeAndTimezone => __truncateWithBinSizeAndTimezone; public static MethodInfo Week => __week; public static MethodInfo WeekWithTimezone => __weekWithTimezone; + + // sets of methods + public static HashSet AddOrSubtractOverloads => __addOrSubtractOverloads; + public static HashSet AddOrSubtractWithTimeSpanOverloads => __addOrSubtractWithTimeSpanOverloads; + public static HashSet AddOrSubtractWithTimezoneOverloads => __addOrSubtractWithTimezoneOverloads; + public static HashSet AddOrSubtractWithUnitOverloads => __addOrSubtractWithUnitOverloads; + public static HashSet SubtractReturningDateTimeOverloads => __subtractReturningDateTimeOverloads; + public static HashSet SubtractReturningInt64Overloads => __subtractReturningInt64Overloads; + public static HashSet SubtractReturningTimeSpanWithMillisecondsUnitsOverloads => __subtractReturningTimeSpanWithMillisecondsUnitsOverloads; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryMethod.cs index dd245eb6c3c..665a4ac8548 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryMethod.cs @@ -21,6 +21,18 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection internal static class DictionaryMethod { // public static methods + public static bool IsContainsKeyMethod(MethodInfo method) + { + return + !method.IsStatic && + method.Name == "ContainsKey" && + method.DeclaringType.ImplementsDictionaryInterface(out var keyType, out _) && + method.GetParameters() is var parameters && + parameters.Length == 1 && + parameters[0].ParameterType == keyType && + method.ReturnType == typeof(bool); + } + public static bool IsGetItemWithKeyMethod(MethodInfo method) { return diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs index a10f2a67531..eb594479353 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs @@ -30,6 +30,7 @@ internal static class EnumerableMethod private static readonly MethodInfo __aggregateWithSeedAndFunc; private static readonly MethodInfo __aggregateWithSeedFuncAndResultSelector; private static readonly MethodInfo __all; + private static readonly MethodInfo __allWithPredicate; private static readonly MethodInfo __any; private static readonly MethodInfo __anyWithPredicate; private static readonly MethodInfo __append; @@ -74,7 +75,7 @@ internal static class EnumerableMethod private static readonly MethodInfo __firstOrDefault; private static readonly MethodInfo __firstOrDefaultWithPredicate; private static readonly MethodInfo __firstWithPredicate; - private static readonly MethodInfo __groupBy; + private static readonly MethodInfo __groupByWithKeySelector; private static readonly MethodInfo __groupByWithKeySelectorAndElementSelector; private static readonly MethodInfo __groupByWithKeySelectorAndResultSelector; private static readonly MethodInfo __groupByWithKeySelectorElementSelectorAndResultSelector; @@ -145,7 +146,7 @@ internal static class EnumerableMethod private static readonly MethodInfo __repeat; private static readonly MethodInfo __reverse; private static readonly MethodInfo __select; - private static readonly MethodInfo __selectMany; + private static readonly MethodInfo __selectManyWithSelector; private static readonly MethodInfo __selectManyWithCollectionSelectorAndResultSelector; private static readonly MethodInfo __selectManyWithCollectionSelectorTakingIndexAndResultSelector; private static readonly MethodInfo __selectManyWithSelectorTakingIndex; @@ -192,6 +193,11 @@ internal static class EnumerableMethod private static readonly MethodInfo __whereWithPredicateTakingIndex; private static readonly MethodInfo __zip; + // sets of methods + private static readonly HashSet __pickOverloads; + private static readonly HashSet __pickWithComputedNOverloads; + private static readonly HashSet __pickWithSortDefinitionOverloads; + // static constructor static EnumerableMethod() { @@ -199,6 +205,7 @@ static EnumerableMethod() __aggregateWithSeedAndFunc = ReflectionInfo.Method((IEnumerable source, object seed, Func func) => source.Aggregate(seed, func)); __aggregateWithSeedFuncAndResultSelector = ReflectionInfo.Method((IEnumerable source, object seed, Func func, Func resultSelector) => source.Aggregate(seed, func, resultSelector)); __all = ReflectionInfo.Method((IEnumerable source, Func predicate) => source.All(predicate)); + __allWithPredicate = ReflectionInfo.Method((IEnumerable source, Func predicate) => source.All(predicate)); __any = ReflectionInfo.Method((IEnumerable source) => source.Any()); __anyWithPredicate = ReflectionInfo.Method((IEnumerable source, Func predicate) => source.Any(predicate)); __append = ReflectionInfo.Method((IEnumerable source, object element) => source.Append(element)); @@ -243,7 +250,7 @@ static EnumerableMethod() __firstOrDefault = ReflectionInfo.Method((IEnumerable source) => source.FirstOrDefault()); __firstOrDefaultWithPredicate = ReflectionInfo.Method((IEnumerable source, Func predicate) => source.FirstOrDefault(predicate)); __firstWithPredicate = ReflectionInfo.Method((IEnumerable source, Func predicate) => source.First(predicate)); - __groupBy = ReflectionInfo.Method((IEnumerable source, Func keySelector) => source.GroupBy(keySelector)); + __groupByWithKeySelector = ReflectionInfo.Method((IEnumerable source, Func keySelector) => source.GroupBy(keySelector)); __groupByWithKeySelectorAndElementSelector = ReflectionInfo.Method((IEnumerable source, Func keySelector, Func elementSelector) => source.GroupBy(keySelector, elementSelector)); __groupByWithKeySelectorAndResultSelector = ReflectionInfo.Method((IEnumerable source, Func keySelector, Func resultSelector) => source.GroupBy(keySelector, resultSelector)); __groupByWithKeySelectorElementSelectorAndResultSelector = ReflectionInfo.Method((IEnumerable source, Func keySelector, Func elementSelector, Func, object> resultSelector) => source.GroupBy(keySelector, elementSelector, resultSelector)); @@ -314,7 +321,7 @@ static EnumerableMethod() __repeat = ReflectionInfo.Method((object element, int count) => Enumerable.Repeat(element, count)); __reverse = ReflectionInfo.Method((IEnumerable source) => source.Reverse()); __select = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Select(selector)); - __selectMany = ReflectionInfo.Method((IEnumerable source, Func> selector) => source.SelectMany(selector)); + __selectManyWithSelector = ReflectionInfo.Method((IEnumerable source, Func> selector) => source.SelectMany(selector)); __selectManyWithCollectionSelectorAndResultSelector = ReflectionInfo.Method((IEnumerable source, Func> collectionSelector, Func resultSelector) => source.SelectMany(collectionSelector, resultSelector)); __selectManyWithCollectionSelectorTakingIndexAndResultSelector = ReflectionInfo.Method((IEnumerable source, Func> collectionSelector, Func resultSelector) => source.SelectMany(collectionSelector, resultSelector)); __selectManyWithSelectorTakingIndex = ReflectionInfo.Method((IEnumerable source, Func> selector) => source.SelectMany(selector)); @@ -360,6 +367,45 @@ static EnumerableMethod() __where = ReflectionInfo.Method((IEnumerable source, Func predicate) => source.Where(predicate)); __whereWithPredicateTakingIndex = ReflectionInfo.Method((IEnumerable source, Func predicate) => source.Where(predicate)); __zip = ReflectionInfo.Method((IEnumerable first, IEnumerable second, Func resultSelector) => first.Zip(second, resultSelector)); + + // initialize sets of methods after individual methods + __pickOverloads = + [ + __bottom, + __bottomN, + __bottomNWithComputedN, + __firstN, + __firstNWithComputedN, + __lastN, + __lastNWithComputedN, + __maxN, + __maxNWithComputedN, + __minN, + __minNWithComputedN, + __top, + __topN, + __topNWithComputedN + ]; + + __pickWithComputedNOverloads = + [ + __bottomNWithComputedN, + __firstNWithComputedN, + __lastNWithComputedN, + __maxNWithComputedN, + __minNWithComputedN, + __topNWithComputedN + ]; + + __pickWithSortDefinitionOverloads = + [ + __bottom, + __bottomN, + __bottomNWithComputedN, + __top, + __topN, + __topNWithComputedN + ]; } // public properties @@ -367,6 +413,7 @@ static EnumerableMethod() public static MethodInfo AggregateWithSeedAndFunc => __aggregateWithSeedAndFunc; public static MethodInfo AggregateWithSeedFuncAndResultSelector => __aggregateWithSeedFuncAndResultSelector; public static MethodInfo All => __all; + public static MethodInfo AllWithPredicate => __allWithPredicate; public static MethodInfo Any => __any; public static MethodInfo AnyWithPredicate => __anyWithPredicate; public static MethodInfo Append => __append; @@ -411,7 +458,7 @@ static EnumerableMethod() public static MethodInfo FirstOrDefault => __firstOrDefault; public static MethodInfo FirstOrDefaultWithPredicate => __firstOrDefaultWithPredicate; public static MethodInfo FirstWithPredicate => __firstWithPredicate; - public static MethodInfo GroupBy => __groupBy; + public static MethodInfo GroupByWithKeySelector => __groupByWithKeySelector; public static MethodInfo GroupByWithKeySelectorAndElementSelector => __groupByWithKeySelectorAndElementSelector; public static MethodInfo GroupByWithKeySelectorAndResultSelector => __groupByWithKeySelectorAndResultSelector; public static MethodInfo GroupByWithKeySelectorElementSelectorAndResultSelector => __groupByWithKeySelectorElementSelectorAndResultSelector; @@ -482,7 +529,7 @@ static EnumerableMethod() public static MethodInfo Repeat => __repeat; public static MethodInfo Reverse => __reverse; public static MethodInfo Select => __select; - public static MethodInfo SelectMany => __selectMany; + public static MethodInfo SelectManyWithSelector => __selectManyWithSelector; public static MethodInfo SelectManyWithCollectionSelectorAndResultSelector => __selectManyWithCollectionSelectorAndResultSelector; public static MethodInfo SelectManyWithCollectionSelectorTakingIndexAndResultSelector => __selectManyWithCollectionSelectorTakingIndexAndResultSelector; public static MethodInfo SelectManyWithSelectorTakingIndex => __selectManyWithSelectorTakingIndex; @@ -529,6 +576,11 @@ static EnumerableMethod() public static MethodInfo WhereWithPredicateTakingIndex => __whereWithPredicateTakingIndex; public static MethodInfo Zip => __zip; + // sets of methods + public static HashSet PickOverloads => __pickOverloads; + public static HashSet PickWithComputedNOverloads => __pickWithComputedNOverloads; + public static HashSet PickWithSortDefinitionOverloads => __pickWithSortDefinitionOverloads; + // public methods public static bool IsContainsMethod(MethodCallExpression methodCallExpression, out Expression sourceExpression, out Expression valueExpression) { diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableOrQueryableMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableOrQueryableMethod.cs new file mode 100644 index 00000000000..2db1ad1b0e8 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableOrQueryableMethod.cs @@ -0,0 +1,721 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Collections.Generic; +using System.Reflection; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection; + +internal static class EnumerableOrQueryableMethod +{ + private static readonly HashSet __aggregateOverloads; + private static readonly HashSet __aggregateWithFunc; + private static readonly HashSet __aggregateWithSeedOverloads; + private static readonly HashSet __aggregateWithSeedAndFunc; + private static readonly HashSet __aggregateWithSeedFuncAndResultSelector; + private static readonly HashSet __all; + private static readonly HashSet __anyOverloads; + private static readonly HashSet __any; + private static readonly HashSet __anyWithPredicate; + private static readonly HashSet __append; + private static readonly HashSet __appendOrPrepend; + private static readonly HashSet __averageOverloads; + private static readonly HashSet __averageWithSelectorOverloads; + private static readonly HashSet __concat; + private static readonly HashSet __countOverloads; + private static readonly HashSet __countWithPredicateOverloads; + private static readonly HashSet __distinct; + private static readonly HashSet __elementAt; + private static readonly HashSet __elementAtOrDefault; + private static readonly HashSet __elementAtOverloads; + private static readonly HashSet __except; + private static readonly HashSet __firstOverloads; + private static readonly HashSet __firstOrDefaultOverloads; + private static readonly HashSet __firstWithPredicateOverloads; + private static readonly HashSet __groupByOverloads; + private static readonly HashSet __groupByWithKeySelector; + private static readonly HashSet __groupByWithKeySelectorAndElementSelector; + private static readonly HashSet __groupByWithKeySelectorAndResultSelector; + private static readonly HashSet __groupByWithKeySelectorElementSelectorAndResultSelector; + private static readonly HashSet __lastOverloads; + private static readonly HashSet __lastOrDefaultOverloads; + private static readonly HashSet __lastWithPredicateOverloads; + private static readonly HashSet __maxOverloads; + private static readonly HashSet __maxWithSelectorOverloads; + private static readonly HashSet __minOverloads; + private static readonly HashSet __minWithSelectorOverloads; + private static readonly HashSet __selectManyOverloads; + private static readonly HashSet __selectManyWithCollectionSelectorAndResultSelector; + private static readonly HashSet __selectManyWithSelector; + private static readonly HashSet __singleOverloads; + private static readonly HashSet __singleWithPredicateOverloads; + private static readonly HashSet __skipOverloads; + private static readonly HashSet __skipWhile; + private static readonly HashSet __sumOverloads; + private static readonly HashSet __sumWithSelectorOverloads; + private static readonly HashSet __takeOverloads; + private static readonly HashSet __takeWhile; + private static readonly HashSet __where; + + // sets of methods + private static readonly HashSet[] __firstOrLastOverloads; + private static readonly HashSet[] __firstOrLastWithPredicateOverloads; + private static readonly HashSet[] __firstOrLastOrSingleOverloads; + private static readonly HashSet[] __firstOrLastOrSingleWithPredicateOverloads; + private static readonly HashSet[] __maxOrMinOverloads; + private static readonly HashSet[] __maxOrMinWithSelectorOverloads; + private static readonly HashSet[] __skipOrTakeOverloads; + private static readonly HashSet[] __skipOrTakeWhile; + + static EnumerableOrQueryableMethod() + { + __aggregateOverloads = + [ + EnumerableMethod.AggregateWithFunc, + EnumerableMethod.AggregateWithSeedAndFunc, + EnumerableMethod.AggregateWithSeedFuncAndResultSelector, + QueryableMethod.AggregateWithFunc, + QueryableMethod.AggregateWithSeedAndFunc, + QueryableMethod.AggregateWithSeedFuncAndResultSelector + ]; + + __aggregateWithFunc = + [ + EnumerableMethod.AggregateWithFunc, + QueryableMethod.AggregateWithFunc + ]; + + __aggregateWithSeedOverloads = + [ + EnumerableMethod.AggregateWithSeedAndFunc, + EnumerableMethod.AggregateWithSeedFuncAndResultSelector, + QueryableMethod.AggregateWithSeedAndFunc, + QueryableMethod.AggregateWithSeedFuncAndResultSelector + ]; + + __aggregateWithSeedAndFunc = + [ + EnumerableMethod.AggregateWithSeedAndFunc, + QueryableMethod.AggregateWithSeedAndFunc + ]; + + __aggregateWithSeedFuncAndResultSelector = + [ + EnumerableMethod.AggregateWithSeedFuncAndResultSelector, + QueryableMethod.AggregateWithSeedFuncAndResultSelector + ]; + + __all = + [ + EnumerableMethod.All, + QueryableMethod.All + ]; + + __any = + [ + EnumerableMethod.Any, + QueryableMethod.Any, + ]; + + __anyOverloads = + [ + EnumerableMethod.Any, + EnumerableMethod.AnyWithPredicate, + QueryableMethod.Any, + QueryableMethod.AnyWithPredicate + ]; + + __anyWithPredicate = + [ + EnumerableMethod.AnyWithPredicate, + QueryableMethod.AnyWithPredicate + ]; + + __append = + [ + EnumerableMethod.Append, + QueryableMethod.Append + ]; + + __appendOrPrepend = + [ + EnumerableMethod.Append, + EnumerableMethod.Prepend, + QueryableMethod.Append, + QueryableMethod.Prepend + ]; + + __averageOverloads = + [ + EnumerableMethod.AverageDecimal, + EnumerableMethod.AverageDecimalWithSelector, + EnumerableMethod.AverageDouble, + EnumerableMethod.AverageDoubleWithSelector, + EnumerableMethod.AverageInt32, + EnumerableMethod.AverageInt32WithSelector, + EnumerableMethod.AverageInt64, + EnumerableMethod.AverageInt64WithSelector, + EnumerableMethod.AverageNullableDecimal, + EnumerableMethod.AverageNullableDecimalWithSelector, + EnumerableMethod.AverageNullableDouble, + EnumerableMethod.AverageNullableDoubleWithSelector, + EnumerableMethod.AverageNullableInt32, + EnumerableMethod.AverageNullableInt32WithSelector, + EnumerableMethod.AverageNullableInt64, + EnumerableMethod.AverageNullableInt64WithSelector, + EnumerableMethod.AverageNullableSingle, + EnumerableMethod.AverageNullableSingleWithSelector, + EnumerableMethod.AverageSingle, + EnumerableMethod.AverageSingleWithSelector, + QueryableMethod.AverageDecimal, + QueryableMethod.AverageDecimalWithSelector, + QueryableMethod.AverageDouble, + QueryableMethod.AverageDoubleWithSelector, + QueryableMethod.AverageInt32, + QueryableMethod.AverageInt32WithSelector, + QueryableMethod.AverageInt64, + QueryableMethod.AverageInt64WithSelector, + QueryableMethod.AverageNullableDecimal, + QueryableMethod.AverageNullableDecimalWithSelector, + QueryableMethod.AverageNullableDouble, + QueryableMethod.AverageNullableDoubleWithSelector, + QueryableMethod.AverageNullableInt32, + QueryableMethod.AverageNullableInt32WithSelector, + QueryableMethod.AverageNullableInt64, + QueryableMethod.AverageNullableInt64WithSelector, + QueryableMethod.AverageNullableSingle, + QueryableMethod.AverageNullableSingleWithSelector, + QueryableMethod.AverageSingle, + QueryableMethod.AverageSingleWithSelector + ]; + + __averageWithSelectorOverloads = + [ + EnumerableMethod.AverageDecimalWithSelector, + EnumerableMethod.AverageDoubleWithSelector, + EnumerableMethod.AverageInt32WithSelector, + EnumerableMethod.AverageInt64WithSelector, + EnumerableMethod.AverageNullableDecimalWithSelector, + EnumerableMethod.AverageNullableDoubleWithSelector, + EnumerableMethod.AverageNullableInt32WithSelector, + EnumerableMethod.AverageNullableInt64WithSelector, + EnumerableMethod.AverageNullableSingleWithSelector, + EnumerableMethod.AverageSingleWithSelector, + QueryableMethod.AverageDecimalWithSelector, + QueryableMethod.AverageDoubleWithSelector, + QueryableMethod.AverageInt32WithSelector, + QueryableMethod.AverageInt64WithSelector, + QueryableMethod.AverageNullableDecimalWithSelector, + QueryableMethod.AverageNullableDoubleWithSelector, + QueryableMethod.AverageNullableInt32WithSelector, + QueryableMethod.AverageNullableInt64WithSelector, + QueryableMethod.AverageNullableSingleWithSelector, + QueryableMethod.AverageSingleWithSelector, + ]; + + __concat = + [ + EnumerableMethod.Concat, + QueryableMethod.Concat + ]; + + __countOverloads = + [ + EnumerableMethod.Count, + EnumerableMethod.CountWithPredicate, + EnumerableMethod.LongCount, // it's convenient to treat LongCount as if it was an overload + EnumerableMethod.LongCountWithPredicate, + QueryableMethod.Count, + QueryableMethod.CountWithPredicate, + QueryableMethod.LongCount, + QueryableMethod.LongCountWithPredicate + ]; + + __countWithPredicateOverloads = + [ + EnumerableMethod.CountWithPredicate, + EnumerableMethod.LongCountWithPredicate, + QueryableMethod.CountWithPredicate, + QueryableMethod.LongCountWithPredicate + ]; + + __distinct = + [ + EnumerableMethod.Distinct, + QueryableMethod.Distinct + ]; + + __elementAt = + [ + EnumerableMethod.ElementAt, + QueryableMethod.ElementAt + ]; + + __elementAtOverloads = + [ + EnumerableMethod.ElementAt, + EnumerableMethod.ElementAtOrDefault, // it's convenient to treat ElementAtOrDefault as if it was an overload + QueryableMethod.ElementAt, + QueryableMethod.ElementAtOrDefault + ]; + + __elementAtOrDefault = + [ + EnumerableMethod.ElementAtOrDefault, + QueryableMethod.ElementAtOrDefault + ]; + + __except = + [ + EnumerableMethod.Except, + QueryableMethod.Except + ]; + + __firstOverloads = + [ + EnumerableMethod.First, + EnumerableMethod.FirstOrDefault, // it's convenient to treat FirstOrDefault as if it was an overload + EnumerableMethod.FirstOrDefaultWithPredicate, + EnumerableMethod.FirstWithPredicate, + QueryableMethod.First, + QueryableMethod.FirstOrDefault, + QueryableMethod.FirstOrDefaultWithPredicate, + QueryableMethod.FirstWithPredicate + ]; + + __firstOrDefaultOverloads = + [ + EnumerableMethod.FirstOrDefault, + EnumerableMethod.FirstOrDefaultWithPredicate, + QueryableMethod.FirstOrDefault, + QueryableMethod.FirstOrDefaultWithPredicate + ]; + + __firstWithPredicateOverloads = + [ + EnumerableMethod.FirstOrDefaultWithPredicate, + EnumerableMethod.FirstWithPredicate, + QueryableMethod.FirstOrDefaultWithPredicate, + QueryableMethod.FirstWithPredicate + ]; + + __groupByOverloads = + [ + EnumerableMethod.GroupByWithKeySelector, + EnumerableMethod.GroupByWithKeySelectorAndElementSelector, + EnumerableMethod.GroupByWithKeySelectorAndResultSelector, + EnumerableMethod.GroupByWithKeySelectorElementSelectorAndResultSelector, + QueryableMethod.GroupByWithKeySelector, + QueryableMethod.GroupByWithKeySelectorAndElementSelector, + QueryableMethod.GroupByWithKeySelectorAndResultSelector, + QueryableMethod.GroupByWithKeySelectorElementSelectorAndResultSelector + ]; + + __groupByWithKeySelector = + [ + EnumerableMethod.GroupByWithKeySelector, + QueryableMethod.GroupByWithKeySelector + ]; + + __groupByWithKeySelectorAndElementSelector = + [ + EnumerableMethod.GroupByWithKeySelectorAndElementSelector, + QueryableMethod.GroupByWithKeySelectorAndElementSelector + ]; + + __groupByWithKeySelectorAndResultSelector = + [ + EnumerableMethod.GroupByWithKeySelectorAndResultSelector, + QueryableMethod.GroupByWithKeySelectorAndResultSelector + ]; + + __groupByWithKeySelectorElementSelectorAndResultSelector = + [ + EnumerableMethod.GroupByWithKeySelectorElementSelectorAndResultSelector, + QueryableMethod.GroupByWithKeySelectorElementSelectorAndResultSelector + ]; + + __lastOverloads = + [ + EnumerableMethod.Last, + EnumerableMethod.LastOrDefault, // it's convenient to treat LastOrDefault as if it was an overload + EnumerableMethod.LastOrDefaultWithPredicate, + EnumerableMethod.LastWithPredicate, + QueryableMethod.Last, + QueryableMethod.LastOrDefault, + QueryableMethod.LastOrDefaultWithPredicate, + QueryableMethod.LastWithPredicate + ]; + + __lastOrDefaultOverloads = + [ + EnumerableMethod.LastOrDefault, + EnumerableMethod.LastOrDefaultWithPredicate, + QueryableMethod.LastOrDefault, + QueryableMethod.LastOrDefaultWithPredicate + ]; + + __lastWithPredicateOverloads = + [ + EnumerableMethod.LastOrDefaultWithPredicate, + EnumerableMethod.LastWithPredicate, + QueryableMethod.LastOrDefaultWithPredicate, + QueryableMethod.LastWithPredicate + ]; + + __maxOverloads = + [ + EnumerableMethod.Max, + EnumerableMethod.MaxDecimal, + EnumerableMethod.MaxDecimalWithSelector, + EnumerableMethod.MaxDouble, + EnumerableMethod.MaxDoubleWithSelector, + EnumerableMethod.MaxInt32, + EnumerableMethod.MaxInt32WithSelector, + EnumerableMethod.MaxInt64, + EnumerableMethod.MaxInt64WithSelector, + EnumerableMethod.MaxNullableDecimal, + EnumerableMethod.MaxNullableDecimalWithSelector, + EnumerableMethod.MaxNullableDouble, + EnumerableMethod.MaxNullableDoubleWithSelector, + EnumerableMethod.MaxNullableInt32, + EnumerableMethod.MaxNullableInt32WithSelector, + EnumerableMethod.MaxNullableInt64, + EnumerableMethod.MaxNullableInt64WithSelector, + EnumerableMethod.MaxNullableSingle, + EnumerableMethod.MaxNullableSingleWithSelector, + EnumerableMethod.MaxSingle, + EnumerableMethod.MaxSingleWithSelector, + EnumerableMethod.MaxWithSelector, + QueryableMethod.Max, + QueryableMethod.MaxWithSelector, + ]; + + __maxWithSelectorOverloads = + [ + EnumerableMethod.MaxDecimalWithSelector, + EnumerableMethod.MaxDoubleWithSelector, + EnumerableMethod.MaxInt32WithSelector, + EnumerableMethod.MaxInt64WithSelector, + EnumerableMethod.MaxNullableDecimalWithSelector, + EnumerableMethod.MaxNullableDoubleWithSelector, + EnumerableMethod.MaxNullableInt32WithSelector, + EnumerableMethod.MaxNullableInt64WithSelector, + EnumerableMethod.MaxNullableSingleWithSelector, + EnumerableMethod.MaxSingleWithSelector, + EnumerableMethod.MaxWithSelector, + QueryableMethod.MaxWithSelector, + QueryableMethod.MinWithSelector, + ]; + + __minOverloads = + [ + EnumerableMethod.Min, + EnumerableMethod.MinDecimal, + EnumerableMethod.MinDecimalWithSelector, + EnumerableMethod.MinDouble, + EnumerableMethod.MinDoubleWithSelector, + EnumerableMethod.MinInt32, + EnumerableMethod.MinInt32WithSelector, + EnumerableMethod.MinInt64, + EnumerableMethod.MinInt64WithSelector, + EnumerableMethod.MinNullableDecimal, + EnumerableMethod.MinNullableDecimalWithSelector, + EnumerableMethod.MinNullableDouble, + EnumerableMethod.MinNullableDoubleWithSelector, + EnumerableMethod.MinNullableInt32, + EnumerableMethod.MinNullableInt32WithSelector, + EnumerableMethod.MinNullableInt64, + EnumerableMethod.MinNullableInt64WithSelector, + EnumerableMethod.MinNullableSingle, + EnumerableMethod.MinNullableSingleWithSelector, + EnumerableMethod.MinSingle, + EnumerableMethod.MinSingleWithSelector, + EnumerableMethod.MinWithSelector, + QueryableMethod.Min, + QueryableMethod.MinWithSelector, + ]; + + __minWithSelectorOverloads = + [ + EnumerableMethod.MinDecimalWithSelector, + EnumerableMethod.MinDoubleWithSelector, + EnumerableMethod.MinInt32WithSelector, + EnumerableMethod.MinInt64WithSelector, + EnumerableMethod.MinNullableDecimalWithSelector, + EnumerableMethod.MinNullableDoubleWithSelector, + EnumerableMethod.MinNullableInt32WithSelector, + EnumerableMethod.MinNullableInt64WithSelector, + EnumerableMethod.MinNullableSingleWithSelector, + EnumerableMethod.MinSingleWithSelector, + EnumerableMethod.MinWithSelector, + ]; + + __selectManyOverloads = + [ + EnumerableMethod.SelectManyWithSelector, + EnumerableMethod.SelectManyWithCollectionSelectorAndResultSelector, + QueryableMethod.SelectManyWithSelector, + QueryableMethod.SelectManyWithCollectionSelectorAndResultSelector + ]; + + __selectManyWithCollectionSelectorAndResultSelector = + [ + EnumerableMethod.SelectManyWithCollectionSelectorAndResultSelector, + QueryableMethod.SelectManyWithCollectionSelectorAndResultSelector + ]; + + __selectManyWithSelector = + [ + EnumerableMethod.SelectManyWithSelector, + QueryableMethod.SelectManyWithSelector + ]; + + __singleOverloads = + [ + EnumerableMethod.Single, + EnumerableMethod.SingleOrDefault, // it's convenient to treat SingleOrDefault as if it was an overload + EnumerableMethod.SingleOrDefaultWithPredicate, + EnumerableMethod.SingleWithPredicate, + QueryableMethod.Single, + QueryableMethod.SingleOrDefault, + QueryableMethod.SingleOrDefaultWithPredicate, + QueryableMethod.SingleWithPredicate + ]; + + __singleWithPredicateOverloads = + [ + EnumerableMethod.SingleOrDefaultWithPredicate, + EnumerableMethod.SingleWithPredicate, + QueryableMethod.SingleOrDefaultWithPredicate, + QueryableMethod.SingleWithPredicate + ]; + + __skipOverloads = + [ + EnumerableMethod.Skip, + EnumerableMethod.SkipWhile, // it's convenient to treat SkipWhile as if it was an overload + QueryableMethod.Skip, + QueryableMethod.SkipWhile, + MongoQueryableMethod.SkipWithLong // it's convenient to group our custom Skip method with the EnumerableOrQueryable Skip methods + ]; + + __skipWhile = + [ + EnumerableMethod.SkipWhile, + QueryableMethod.SkipWhile + ]; + + __sumOverloads = + [ + EnumerableMethod.SumDecimal, + EnumerableMethod.SumDecimalWithSelector, + EnumerableMethod.SumDouble, + EnumerableMethod.SumDoubleWithSelector, + EnumerableMethod.SumInt32, + EnumerableMethod.SumInt32WithSelector, + EnumerableMethod.SumInt64, + EnumerableMethod.SumInt64WithSelector, + EnumerableMethod.SumNullableDecimal, + EnumerableMethod.SumNullableDecimalWithSelector, + EnumerableMethod.SumNullableDouble, + EnumerableMethod.SumNullableDoubleWithSelector, + EnumerableMethod.SumNullableInt32, + EnumerableMethod.SumNullableInt32WithSelector, + EnumerableMethod.SumNullableInt64, + EnumerableMethod.SumNullableInt64WithSelector, + EnumerableMethod.SumNullableSingle, + EnumerableMethod.SumNullableSingleWithSelector, + EnumerableMethod.SumSingle, + EnumerableMethod.SumSingleWithSelector, + QueryableMethod.SumDecimal, + QueryableMethod.SumDecimalWithSelector, + QueryableMethod.SumDouble, + QueryableMethod.SumDoubleWithSelector, + QueryableMethod.SumInt32, + QueryableMethod.SumInt32WithSelector, + QueryableMethod.SumInt64, + QueryableMethod.SumInt64WithSelector, + QueryableMethod.SumNullableDecimal, + QueryableMethod.SumNullableDecimalWithSelector, + QueryableMethod.SumNullableDouble, + QueryableMethod.SumNullableDoubleWithSelector, + QueryableMethod.SumNullableInt32, + QueryableMethod.SumNullableInt32WithSelector, + QueryableMethod.SumNullableInt64, + QueryableMethod.SumNullableInt64WithSelector, + QueryableMethod.SumNullableSingle, + QueryableMethod.SumNullableSingleWithSelector, + QueryableMethod.SumSingle, + QueryableMethod.SumSingleWithSelector + ]; + + __sumWithSelectorOverloads = + [ + EnumerableMethod.SumDecimalWithSelector, + EnumerableMethod.SumDoubleWithSelector, + EnumerableMethod.SumInt32WithSelector, + EnumerableMethod.SumInt64WithSelector, + EnumerableMethod.SumNullableDecimalWithSelector, + EnumerableMethod.SumNullableDoubleWithSelector, + EnumerableMethod.SumNullableInt32WithSelector, + EnumerableMethod.SumNullableInt64WithSelector, + EnumerableMethod.SumNullableSingleWithSelector, + EnumerableMethod.SumSingleWithSelector, + QueryableMethod.SumDecimalWithSelector, + QueryableMethod.SumDoubleWithSelector, + QueryableMethod.SumInt32WithSelector, + QueryableMethod.SumInt64WithSelector, + QueryableMethod.SumNullableDecimalWithSelector, + QueryableMethod.SumNullableDoubleWithSelector, + QueryableMethod.SumNullableInt32WithSelector, + QueryableMethod.SumNullableInt64WithSelector, + QueryableMethod.SumNullableSingleWithSelector, + QueryableMethod.SumSingleWithSelector, + ]; + + __takeOverloads = + [ + EnumerableMethod.Take, + EnumerableMethod.TakeWhile, // it's convenient to treat TakeWhile as if it was an overload + QueryableMethod.Take, + QueryableMethod.TakeWhile, + MongoQueryableMethod.TakeWithLong // it's convenient to group our custom Take method with the EnumerableOrQueryable Take methods + ]; + + __takeWhile = + [ + EnumerableMethod.TakeWhile, + QueryableMethod.TakeWhile + ]; + + __where = + [ + EnumerableMethod.Where, + QueryableMethod.Where, + ]; + + // initialize arrays of sets of methods after sets of methods + __firstOrLastOverloads = + [ + __firstOverloads, + __lastOverloads + ]; + + __firstOrLastWithPredicateOverloads = + [ + __firstWithPredicateOverloads, + __lastWithPredicateOverloads + ]; + + __firstOrLastOrSingleOverloads = + [ + __firstOverloads, + __lastOverloads, + __singleOverloads + ]; + + __firstOrLastOrSingleWithPredicateOverloads = + [ + __firstWithPredicateOverloads, + __lastWithPredicateOverloads, + __singleWithPredicateOverloads + ]; + + __maxOrMinOverloads = + [ + __maxOverloads, + __minOverloads + ]; + + __maxOrMinWithSelectorOverloads = + [ + __maxWithSelectorOverloads, + __minWithSelectorOverloads + ]; + + __skipOrTakeOverloads = + [ + __skipOverloads, + __takeOverloads + ]; + + __skipOrTakeWhile = + [ + __skipWhile, + __takeWhile + ]; + } + + public static HashSet AggregateOverloads => __aggregateOverloads; + public static HashSet AggregateWithFunc => __aggregateWithFunc; + public static HashSet AggregateWithSeedOverloads => __aggregateWithSeedOverloads; + public static HashSet AggregateWithSeedAndFunc => __aggregateWithSeedAndFunc; + public static HashSet AggregateWithSeedFuncAndResultSelector => __aggregateWithSeedFuncAndResultSelector; + public static HashSet All => __all; + public static HashSet AnyOverloads => __anyOverloads; + public static HashSet Any => __any; + public static HashSet AnyWithPredicate => __anyWithPredicate; + public static HashSet Append => __append; + public static HashSet AppendOrPrepend => __appendOrPrepend; + public static HashSet AverageOverloads => __averageOverloads; + public static HashSet AverageWithSelectorOverloads => __averageWithSelectorOverloads; + public static HashSet Concat => __concat; + public static HashSet CountOverloads => __countOverloads; + public static HashSet CountWithPredicateOverloads => __countWithPredicateOverloads; + public static HashSet Distinct => __distinct; + public static HashSet ElementAt => __elementAt; + public static HashSet ElementAtOrDefault => __elementAtOrDefault; + public static HashSet ElementAtOverloads => __elementAtOverloads; + public static HashSet Except => __except; + public static HashSet FirstOverloads => __firstOverloads; + public static HashSet FirstOrDefaultOverloads => __firstOrDefaultOverloads; + public static HashSet FirstWithPredicateOverloads => __firstWithPredicateOverloads; + public static HashSet GroupByOverloads => __groupByOverloads; + public static HashSet GroupByWithKeySelector => __groupByWithKeySelector; + public static HashSet GroupByWithKeySelectorAndElementSelector => __groupByWithKeySelectorAndElementSelector; + public static HashSet GroupByWithKeySelectorAndResultSelector => __groupByWithKeySelectorAndResultSelector; + public static HashSet GroupByWithKeySelectorElementSelectorAndResultSelector => __groupByWithKeySelectorElementSelectorAndResultSelector; + public static HashSet LastOverloads => __lastOverloads; + public static HashSet LastOrDefaultOverloads => __lastOrDefaultOverloads; + public static HashSet LastWithPredicateOverloads => __lastWithPredicateOverloads; + public static HashSet MaxOverloads => __maxOverloads; + public static HashSet MaxWithSelectorOverloads => __maxWithSelectorOverloads; + public static HashSet MinOverloads => __minOverloads; + public static HashSet MinWithSelectorOverloads => __minWithSelectorOverloads; + public static HashSet SelectManyOverloads => __selectManyOverloads; + public static HashSet SelectManyWithCollectionSelectorAndResultSelector => __selectManyWithCollectionSelectorAndResultSelector; + public static HashSet SelectManyWithSelector => __selectManyWithSelector; + public static HashSet SingleOverloads => __singleOverloads; + public static HashSet SingleWithPredicateOverloads => __singleWithPredicateOverloads; + public static HashSet SkipOverloads => __skipOverloads; + public static HashSet SkipWhile => __skipWhile; + public static HashSet SumOverloads => __sumOverloads; + public static HashSet SumWithSelectorOverloads => __sumWithSelectorOverloads; + + public static HashSet TakeOverloads => __takeOverloads; + public static HashSet TakeWhile => __takeWhile; + public static HashSet Where => __where; + + // arrays of sets of methods + public static HashSet[] FirstOrLastOverloads => __firstOrLastOverloads; + public static HashSet[] FirstOrLastWithPredicateOverloads => __firstOrLastWithPredicateOverloads; + public static HashSet[] FirstOrLastOrSingleOverloads => __firstOrLastOrSingleOverloads; + public static HashSet[] FirstOrLastOrSingleWithPredicateOverloads => __firstOrLastOrSingleWithPredicateOverloads; + public static HashSet[] MaxOrMinOverloads => __maxOrMinOverloads; + public static HashSet[] MaxOrMinWithSelectorOverloads => __maxOrMinWithSelectorOverloads; + public static HashSet[] SkipOrTakeOverloads => __skipOrTakeOverloads; + public static HashSet[] SkipOrTakeWhile => __skipOrTakeWhile; +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableProperty.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableProperty.cs index 6e929c18d3c..845b0702a23 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableProperty.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableProperty.cs @@ -38,9 +38,9 @@ expression.Member is PropertyInfo propertyInfo && static bool ImplementsCollectionInterface(Type type) => - type.Implements(typeof(ICollection)) || - type.Implements(typeof(ICollection<>)) || - type.Implements(typeof(IReadOnlyCollection<>)); + type.ImplementsInterface(typeof(ICollection)) || + type.ImplementsInterface(typeof(ICollection<>)) || + type.ImplementsInterface(typeof(IReadOnlyCollection<>)); } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/HashSetConstructor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/HashSetConstructor.cs new file mode 100644 index 00000000000..3b062091f53 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/HashSetConstructor.cs @@ -0,0 +1,41 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Collections.Generic; +using System.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection +{ + internal static class HashSetConstructor + { + public static bool IsWithCollectionConstructor(ConstructorInfo constructor) + { + if (constructor != null) + { + var declaringType = constructor.DeclaringType; + var parameters = constructor.GetParameters(); + return + declaringType.IsConstructedGenericType && + declaringType.GetGenericTypeDefinition() == typeof(HashSet<>) && + parameters.Length == 1 && + parameters[0].ParameterType.ImplementsIEnumerable(out var itemType) && + itemType == declaringType.GenericTypeArguments[0]; + } + + return false; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/ISetMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/ISetMethod.cs new file mode 100644 index 00000000000..38898ffd9ef --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/ISetMethod.cs @@ -0,0 +1,41 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Collections.Generic; +using System.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection +{ + internal static class ISetMethod + { + // public static methods + public static bool IsSetEqualsMethod(MethodInfo method) + { + // many types implement a SetEquals method but the declaringType should always implement ISet + var declaringType = method.DeclaringType; + return + declaringType.ImplementsISet(out var itemType) && + method.IsPublic && + !method.IsStatic && + method.ReturnType == typeof(bool) && + method.Name == "SetEquals" && + method.GetParameters() is var parameters && + parameters.Length == 1 && + parameters[0] is var otherParameter && + otherParameter.ParameterType == typeof(IEnumerable<>).MakeGenericType(itemType); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/KeyValuePairConstructor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/KeyValuePairConstructor.cs new file mode 100644 index 00000000000..5ffa9e126f2 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/KeyValuePairConstructor.cs @@ -0,0 +1,44 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Collections.Generic; +using System.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection +{ + internal static class KeyValuePairConstructor + { + public static bool IsWithKeyAndValueConstructor(ConstructorInfo constructor) + { + if (constructor != null) + { + var declaringType = constructor.DeclaringType; + var parameters = constructor.GetParameters(); + return + declaringType.IsConstructedGenericType && + declaringType.GetGenericTypeDefinition() == typeof(KeyValuePair<,>) && + declaringType.GetGenericArguments() is var typeParameters && + typeParameters[0] is var keyType && + typeParameters[1] is var valueType && + parameters.Length == 2 && + parameters[0].ParameterType == keyType && + parameters[1].ParameterType == valueType; + } + + return false; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/ListConstructor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/ListConstructor.cs new file mode 100644 index 00000000000..21c731c7ceb --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/ListConstructor.cs @@ -0,0 +1,41 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Collections.Generic; +using System.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection +{ + internal static class ListConstructor + { + public static bool IsWithCollectionConstructor(ConstructorInfo constructor) + { + if (constructor != null) + { + var declaringType = constructor.DeclaringType; + var parameters = constructor.GetParameters(); + return + declaringType.IsConstructedGenericType && + declaringType.GetGenericTypeDefinition() == typeof(List<>) && + parameters.Length == 1 && + parameters[0].ParameterType.ImplementsIEnumerable(out var itemType) && + itemType == declaringType.GenericTypeArguments[0]; + } + + return false; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MathMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MathMethod.cs index 28623c07a47..ad6aaba4923 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MathMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MathMethod.cs @@ -14,6 +14,7 @@ */ using System; +using System.Collections.Generic; using System.Reflection; namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection @@ -58,9 +59,20 @@ internal static class MathMethod private static readonly MethodInfo __truncateDecimal; private static readonly MethodInfo __truncateDouble; + // sets of methods + private static readonly HashSet __absOverloads; + private static readonly HashSet __logOverloads; + private static readonly HashSet __trigonometricMethods; + // static constructor static MathMethod() { +#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER + __acosh = ReflectionInfo.Method((double d) => Math.Acosh(d)); + __asinh = ReflectionInfo.Method((double d) => Math.Asinh(d)); + __atanh = ReflectionInfo.Method((double d) => Math.Atanh(d)); +#endif + __absDecimal = ReflectionInfo.Method((decimal value) => Math.Abs(value)); __absDouble = ReflectionInfo.Method((double value) => Math.Abs(value)); __absInt16 = ReflectionInfo.Method((short value) => Math.Abs(value)); @@ -69,18 +81,9 @@ static MathMethod() __absSByte = ReflectionInfo.Method((sbyte value) => Math.Abs(value)); __absSingle = ReflectionInfo.Method((float value) => Math.Abs(value)); __acos = ReflectionInfo.Method((double d) => Math.Acos(d)); -#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER - __acosh = ReflectionInfo.Method((double d) => Math.Acosh(d)); -#endif __asin = ReflectionInfo.Method((double d) => Math.Asin(d)); -#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER - __asinh = ReflectionInfo.Method((double d) => Math.Asinh(d)); -#endif __atan = ReflectionInfo.Method((double d) => Math.Atan(d)); __atan2 = ReflectionInfo.Method((double x, double y) => Math.Atan2(x, y)); -#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER - __atanh = ReflectionInfo.Method((double d) => Math.Atanh(d)); -#endif __ceilingWithDecimal = ReflectionInfo.Method((decimal d) => Math.Ceiling(d)); __ceilingWithDouble = ReflectionInfo.Method((double a) => Math.Ceiling(a)); __cos = ReflectionInfo.Method((double d) => Math.Cos(d)); @@ -103,6 +106,42 @@ static MathMethod() __tanh = ReflectionInfo.Method((double a) => Math.Tanh(a)); __truncateDecimal = ReflectionInfo.Method((decimal d) => Math.Truncate(d)); __truncateDouble = ReflectionInfo.Method((double d) => Math.Truncate(d)); + + // sets of methods + __absOverloads = + [ + __absDecimal, + __absDouble, + __absInt16, + __absInt32, + __absInt64, + __absSByte, + __absSingle + ]; + + __logOverloads = + [ + __log, + __log10, // it's convenient to treat Log10 as if it was an overload + __logWithNewBase + ]; + + __trigonometricMethods = + [ + __acos, + __acosh, + __asin, + __asinh, + __atan, + __atanh, + __atan2, + __cos, + __cosh, + __sin, + __sinh, + __tan, + __tanh + ]; } // public properties @@ -142,5 +181,10 @@ static MathMethod() public static MethodInfo Tanh => __tanh; public static MethodInfo TruncateDecimal => __truncateDecimal; public static MethodInfo TruncateDouble => __truncateDouble; + + // sets of methods + public static HashSet AbsOverloads => __absOverloads; + public static HashSet LogOverloads => __logOverloads; + public static HashSet TrigonometricMethods => __trigonometricMethods; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoEnumerableMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoEnumerableMethod.cs index c10550024c3..4e2915fc956 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoEnumerableMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoEnumerableMethod.cs @@ -65,8 +65,56 @@ internal static class MongoEnumerableMethod private static readonly MethodInfo __percentileNullableSingleWithSelector; private static readonly MethodInfo __percentileSingle; private static readonly MethodInfo __percentileSingleWithSelector; + private static readonly MethodInfo __standardDeviationPopulationDecimal; + private static readonly MethodInfo __standardDeviationPopulationDecimalWithSelector; + private static readonly MethodInfo __standardDeviationPopulationDouble; + private static readonly MethodInfo __standardDeviationPopulationDoubleWithSelector; + private static readonly MethodInfo __standardDeviationPopulationInt32; + private static readonly MethodInfo __standardDeviationPopulationInt32WithSelector; + private static readonly MethodInfo __standardDeviationPopulationInt64; + private static readonly MethodInfo __standardDeviationPopulationInt64WithSelector; + private static readonly MethodInfo __standardDeviationPopulationNullableDecimal; + private static readonly MethodInfo __standardDeviationPopulationNullableDecimalWithSelector; + private static readonly MethodInfo __standardDeviationPopulationNullableDouble; + private static readonly MethodInfo __standardDeviationPopulationNullableDoubleWithSelector; + private static readonly MethodInfo __standardDeviationPopulationNullableInt32; + private static readonly MethodInfo __standardDeviationPopulationNullableInt32WithSelector; + private static readonly MethodInfo __standardDeviationPopulationNullableInt64; + private static readonly MethodInfo __standardDeviationPopulationNullableInt64WithSelector; + private static readonly MethodInfo __standardDeviationPopulationNullableSingle; + private static readonly MethodInfo __standardDeviationPopulationNullableSingleWithSelector; + private static readonly MethodInfo __standardDeviationPopulationSingle; + private static readonly MethodInfo __standardDeviationPopulationSingleWithSelector; + private static readonly MethodInfo __standardDeviationSampleDecimal; + private static readonly MethodInfo __standardDeviationSampleDecimalWithSelector; + private static readonly MethodInfo __standardDeviationSampleDouble; + private static readonly MethodInfo __standardDeviationSampleDoubleWithSelector; + private static readonly MethodInfo __standardDeviationSampleInt32; + private static readonly MethodInfo __standardDeviationSampleInt32WithSelector; + private static readonly MethodInfo __standardDeviationSampleInt64; + private static readonly MethodInfo __standardDeviationSampleInt64WithSelector; + private static readonly MethodInfo __standardDeviationSampleNullableDecimal; + private static readonly MethodInfo __standardDeviationSampleNullableDecimalWithSelector; + private static readonly MethodInfo __standardDeviationSampleNullableDouble; + private static readonly MethodInfo __standardDeviationSampleNullableDoubleWithSelector; + private static readonly MethodInfo __standardDeviationSampleNullableInt32; + private static readonly MethodInfo __standardDeviationSampleNullableInt32WithSelector; + private static readonly MethodInfo __standardDeviationSampleNullableInt64; + private static readonly MethodInfo __standardDeviationSampleNullableInt64WithSelector; + private static readonly MethodInfo __standardDeviationSampleNullableSingle; + private static readonly MethodInfo __standardDeviationSampleNullableSingleWithSelector; + private static readonly MethodInfo __standardDeviationSampleSingle; + private static readonly MethodInfo __standardDeviationSampleSingleWithSelector; private static readonly MethodInfo __whereWithLimit; + // sets of methods + private static readonly HashSet __medianOverloads; + private static readonly HashSet __medianWithSelectorOverloads; + private static readonly HashSet __percentileOverloads; + private static readonly HashSet __percentileWithSelectorOverloads; + private static readonly HashSet __standardDeviationOverloads; + private static readonly HashSet __standardDeviationWithSelectorOverloads; + // static constructor static MongoEnumerableMethod() { @@ -113,7 +161,192 @@ static MongoEnumerableMethod() __percentileNullableSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); __percentileSingle = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); __percentileSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); + __standardDeviationPopulationDecimal = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationDouble = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationDoubleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationInt32 = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationInt32WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationInt64 = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationInt64WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationNullableDecimal = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationNullableDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationNullableDouble = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationNullableDoubleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationNullableInt32 = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationNullableInt32WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationNullableInt64 = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationNullableInt64WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationNullableSingle = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationNullableSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationSingle = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationSampleDecimal = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleDouble = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleDoubleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleInt32 = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleInt32WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleInt64 = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleInt64WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleNullableDecimal = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleNullableDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleNullableDouble = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleNullableDoubleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleNullableInt32 = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleNullableInt32WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleNullableInt64 = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleNullableInt64WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleNullableSingle = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleNullableSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleSingle = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); __whereWithLimit = ReflectionInfo.Method((IEnumerable source, Func predicate, int limit) => source.Where(predicate, limit)); + + // initialize sets of methods after individual methods + __medianOverloads = + [ + __medianDecimal, + __medianDecimalWithSelector, + __medianDouble, + __medianDoubleWithSelector, + __medianInt32, + __medianInt32WithSelector, + __medianInt64, + __medianInt64WithSelector, + __medianNullableDecimal, + __medianNullableDecimalWithSelector, + __medianNullableDouble, + __medianNullableDoubleWithSelector, + __medianNullableInt32, + __medianNullableInt32WithSelector, + __medianNullableInt64, + __medianNullableInt64WithSelector, + __medianNullableSingle, + __medianNullableSingleWithSelector, + __medianSingle, + __medianSingleWithSelector + ]; + + __medianWithSelectorOverloads = + [ + __medianDecimalWithSelector, + __medianDoubleWithSelector, + __medianInt32WithSelector, + __medianInt64WithSelector, + __medianNullableDecimalWithSelector, + __medianNullableDoubleWithSelector, + __medianNullableInt32WithSelector, + __medianNullableInt64WithSelector, + __medianNullableSingleWithSelector, + __medianSingleWithSelector + ]; + + __percentileOverloads = + [ + __percentileDecimal, + __percentileDecimalWithSelector, + __percentileDouble, + __percentileDoubleWithSelector, + __percentileInt32, + __percentileInt32WithSelector, + __percentileInt64, + __percentileInt64WithSelector, + __percentileNullableDecimal, + __percentileNullableDecimalWithSelector, + __percentileNullableDouble, + __percentileNullableDoubleWithSelector, + __percentileNullableInt32, + __percentileNullableInt32WithSelector, + __percentileNullableInt64, + __percentileNullableInt64WithSelector, + __percentileNullableSingle, + __percentileNullableSingleWithSelector, + __percentileSingle, + __percentileSingleWithSelector + ]; + + __percentileWithSelectorOverloads = + [ + __percentileDecimalWithSelector, + __percentileDoubleWithSelector, + __percentileInt32WithSelector, + __percentileInt64WithSelector, + __percentileNullableDecimalWithSelector, + __percentileNullableDoubleWithSelector, + __percentileNullableInt32WithSelector, + __percentileNullableInt64WithSelector, + __percentileNullableSingleWithSelector, + __percentileSingleWithSelector + ]; + + __standardDeviationOverloads = + [ + __standardDeviationPopulationDecimal, + __standardDeviationPopulationDecimalWithSelector, + __standardDeviationPopulationDouble, + __standardDeviationPopulationDoubleWithSelector, + __standardDeviationPopulationInt32, + __standardDeviationPopulationInt32WithSelector, + __standardDeviationPopulationInt64, + __standardDeviationPopulationInt64WithSelector, + __standardDeviationPopulationNullableDecimal, + __standardDeviationPopulationNullableDecimalWithSelector, + __standardDeviationPopulationNullableDouble, + __standardDeviationPopulationNullableDoubleWithSelector, + __standardDeviationPopulationNullableInt32, + __standardDeviationPopulationNullableInt32WithSelector, + __standardDeviationPopulationNullableInt64, + __standardDeviationPopulationNullableInt64WithSelector, + __standardDeviationPopulationNullableSingle, + __standardDeviationPopulationNullableSingleWithSelector, + __standardDeviationPopulationSingle, + __standardDeviationPopulationSingleWithSelector, + __standardDeviationSampleDecimal, + __standardDeviationSampleDecimalWithSelector, + __standardDeviationSampleDouble, + __standardDeviationSampleDoubleWithSelector, + __standardDeviationSampleInt32, + __standardDeviationSampleInt32WithSelector, + __standardDeviationSampleInt64, + __standardDeviationSampleInt64WithSelector, + __standardDeviationSampleNullableDecimal, + __standardDeviationSampleNullableDecimalWithSelector, + __standardDeviationSampleNullableDouble, + __standardDeviationSampleNullableDoubleWithSelector, + __standardDeviationSampleNullableInt32, + __standardDeviationSampleNullableInt32WithSelector, + __standardDeviationSampleNullableInt64, + __standardDeviationSampleNullableInt64WithSelector, + __standardDeviationSampleNullableSingle, + __standardDeviationSampleNullableSingleWithSelector, + __standardDeviationSampleSingle, + __standardDeviationSampleSingleWithSelector, + ]; + + __standardDeviationWithSelectorOverloads = + [ + __standardDeviationPopulationDecimalWithSelector, + __standardDeviationPopulationDoubleWithSelector, + __standardDeviationPopulationInt32WithSelector, + __standardDeviationPopulationInt64WithSelector, + __standardDeviationPopulationNullableDecimalWithSelector, + __standardDeviationPopulationNullableDoubleWithSelector, + __standardDeviationPopulationNullableInt32WithSelector, + __standardDeviationPopulationNullableInt64WithSelector, + __standardDeviationPopulationNullableSingleWithSelector, + __standardDeviationPopulationSingleWithSelector, + __standardDeviationSampleDecimalWithSelector, + __standardDeviationSampleDoubleWithSelector, + __standardDeviationSampleInt32WithSelector, + __standardDeviationSampleInt64WithSelector, + __standardDeviationSampleNullableDecimalWithSelector, + __standardDeviationSampleNullableDoubleWithSelector, + __standardDeviationSampleNullableInt32WithSelector, + __standardDeviationSampleNullableInt64WithSelector, + __standardDeviationSampleNullableSingleWithSelector, + __standardDeviationSampleSingleWithSelector, + ]; } // public properties @@ -160,6 +393,54 @@ static MongoEnumerableMethod() public static MethodInfo PercentileNullableSingleWithSelector => __percentileNullableSingleWithSelector; public static MethodInfo PercentileSingle => __percentileSingle; public static MethodInfo PercentileSingleWithSelector => __percentileSingleWithSelector; + public static MethodInfo StandardDeviationPopulationDecimal => __standardDeviationPopulationDecimal; + public static MethodInfo StandardDeviationPopulationDecimalWithSelector => __standardDeviationPopulationDecimalWithSelector; + public static MethodInfo StandardDeviationPopulationDouble => __standardDeviationPopulationDouble; + public static MethodInfo StandardDeviationPopulationDoubleWithSelector => __standardDeviationPopulationDoubleWithSelector; + public static MethodInfo StandardDeviationPopulationInt32 => __standardDeviationPopulationInt32; + public static MethodInfo StandardDeviationPopulationInt32WithSelector => __standardDeviationPopulationInt32WithSelector; + public static MethodInfo StandardDeviationPopulationInt64 => __standardDeviationPopulationInt64; + public static MethodInfo StandardDeviationPopulationInt64WithSelector => __standardDeviationPopulationInt64WithSelector; + public static MethodInfo StandardDeviationPopulationNullableDecimal => __standardDeviationPopulationNullableDecimal; + public static MethodInfo StandardDeviationPopulationNullableDecimalWithSelector => __standardDeviationPopulationNullableDecimalWithSelector; + public static MethodInfo StandardDeviationPopulationNullableDouble => __standardDeviationPopulationNullableDouble; + public static MethodInfo StandardDeviationPopulationNullableDoubleWithSelector => __standardDeviationPopulationNullableDoubleWithSelector; + public static MethodInfo StandardDeviationPopulationNullableInt32 => __standardDeviationPopulationNullableInt32; + public static MethodInfo StandardDeviationPopulationNullableInt32WithSelector => __standardDeviationPopulationNullableInt32WithSelector; + public static MethodInfo StandardDeviationPopulationNullableInt64 => __standardDeviationPopulationNullableInt64; + public static MethodInfo StandardDeviationPopulationNullableInt64WithSelector => __standardDeviationPopulationNullableInt64WithSelector; + public static MethodInfo StandardDeviationPopulationNullableSingle => __standardDeviationPopulationNullableSingle; + public static MethodInfo StandardDeviationPopulationNullableSingleWithSelector => __standardDeviationPopulationNullableSingleWithSelector; + public static MethodInfo StandardDeviationPopulationSingle => __standardDeviationPopulationSingle; + public static MethodInfo StandardDeviationPopulationSingleWithSelector => __standardDeviationPopulationSingleWithSelector; + public static MethodInfo StandardDeviationSampleDecimal => __standardDeviationSampleDecimal; + public static MethodInfo StandardDeviationSampleDecimalWithSelector => __standardDeviationSampleDecimalWithSelector; + public static MethodInfo StandardDeviationSampleDouble => __standardDeviationSampleDouble; + public static MethodInfo StandardDeviationSampleDoubleWithSelector => __standardDeviationSampleDoubleWithSelector; + public static MethodInfo StandardDeviationSampleInt32 => __standardDeviationSampleInt32; + public static MethodInfo StandardDeviationSampleInt32WithSelector => __standardDeviationSampleInt32WithSelector; + public static MethodInfo StandardDeviationSampleInt64 => __standardDeviationSampleInt64; + public static MethodInfo StandardDeviationSampleInt64WithSelector => __standardDeviationSampleInt64WithSelector; + public static MethodInfo StandardDeviationSampleNullableDecimal => __standardDeviationSampleNullableDecimal; + public static MethodInfo StandardDeviationSampleNullableDecimalWithSelector => __standardDeviationSampleNullableDecimalWithSelector; + public static MethodInfo StandardDeviationSampleNullableDouble => __standardDeviationSampleNullableDouble; + public static MethodInfo StandardDeviationSampleNullableDoubleWithSelector => __standardDeviationSampleNullableDoubleWithSelector; + public static MethodInfo StandardDeviationSampleNullableInt32 => __standardDeviationSampleNullableInt32; + public static MethodInfo StandardDeviationSampleNullableInt32WithSelector => __standardDeviationSampleNullableInt32WithSelector; + public static MethodInfo StandardDeviationSampleNullableInt64 => __standardDeviationSampleNullableInt64; + public static MethodInfo StandardDeviationSampleNullableInt64WithSelector => __standardDeviationSampleNullableInt64WithSelector; + public static MethodInfo StandardDeviationSampleNullableSingle => __standardDeviationSampleNullableSingle; + public static MethodInfo StandardDeviationSampleNullableSingleWithSelector => __standardDeviationSampleNullableSingleWithSelector; + public static MethodInfo StandardDeviationSampleSingle => __standardDeviationSampleSingle; + public static MethodInfo StandardDeviationSampleSingleWithSelector => __standardDeviationSampleSingleWithSelector; public static MethodInfo WhereWithLimit => __whereWithLimit; + + // sets of methods + public static HashSet MedianOverloads => __medianOverloads; + public static HashSet MedianWithSelectorOverloads => __medianWithSelectorOverloads; + public static HashSet PercentileOverloads => __percentileOverloads; + public static HashSet PercentileWithSelectorOverloads => __percentileWithSelectorOverloads; + public static HashSet StandardDeviationOverloads => __standardDeviationOverloads; + public static HashSet StandardDeviationWithSelectorOverloads => __standardDeviationWithSelectorOverloads; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoQueryableMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoQueryableMethod.cs index 245c04c6733..9edccab3c57 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoQueryableMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoQueryableMethod.cs @@ -179,6 +179,9 @@ internal static class MongoQueryableMethod private static readonly MethodInfo __sumSingleWithSelectorAsync; private static readonly MethodInfo __takeWithLong; + private static readonly HashSet __lookupOverloads; + private static readonly HashSet __skipOrTakeWithLong; + // static constructor static MongoQueryableMethod() { @@ -334,6 +337,22 @@ static MongoQueryableMethod() __sumSingleAsync = ReflectionInfo.Method((IQueryable source, CancellationToken cancellationToken) => source.SumAsync(cancellationToken)); __sumSingleWithSelectorAsync = ReflectionInfo.Method((IQueryable source, Expression> selector, CancellationToken cancellationToken) => source.SumAsync(selector, cancellationToken)); __takeWithLong = ReflectionInfo.Method((IQueryable source, long count) => source.Take(count)); + + __lookupOverloads = + [ + __lookupWithDocumentsAndLocalFieldAndForeignField, + __lookupWithDocumentsAndLocalFieldAndForeignFieldAndPipeline, + __lookupWithDocumentsAndPipeline, + __lookupWithFromAndLocalFieldAndForeignField, + __lookupWithFromAndLocalFieldAndForeignFieldAndPipeline, + __lookupWithFromAndPipeline + ]; + + __skipOrTakeWithLong = + [ + __skipWithLong, + __takeWithLong + ]; } // public properties @@ -489,5 +508,9 @@ static MongoQueryableMethod() public static MethodInfo SumSingleAsync => __sumSingleAsync; public static MethodInfo SumSingleWithSelectorAsync => __sumSingleWithSelectorAsync; public static MethodInfo TakeWithLong => __takeWithLong; + + // sets of methods + public static HashSet LookupOverloads => __lookupOverloads; + public static HashSet SkipOrTakeWithLong => __skipOrTakeWithLong; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MqlMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MqlMethod.cs index 4b82e4a545c..1e11284cb82 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MqlMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MqlMethod.cs @@ -14,6 +14,7 @@ */ using System; +using System.Collections.Generic; using System.Reflection; using MongoDB.Bson; using MongoDB.Bson.Serialization; @@ -36,6 +37,11 @@ internal static class MqlMethod private static readonly MethodInfo __isNullOrMissing; private static readonly MethodInfo __sigmoid; + // sets of methods + private static readonly HashSet __dateFromStringOverloads; + private static readonly HashSet __dateFromStringWithFormatOverloads; + private static readonly HashSet __dateFromStringWithTimezoneOverloads; + // static constructor static MqlMethod() { @@ -51,6 +57,28 @@ static MqlMethod() __isMissing = ReflectionInfo.Method((object field) => Mql.IsMissing(field)); __isNullOrMissing = ReflectionInfo.Method((object field) => Mql.IsNullOrMissing(field)); __sigmoid = ReflectionInfo.Method((double value) => Mql.Sigmoid(value)); + + // initialize sets of methods after individual methods + __dateFromStringOverloads = + [ + __dateFromString, + __dateFromStringWithFormat, + __dateFromStringWithFormatAndTimezone, + __dateFromStringWithFormatAndTimezoneAndOnErrorAndOnNull + ]; + + __dateFromStringWithFormatOverloads = + [ + __dateFromStringWithFormat, + __dateFromStringWithFormatAndTimezone, + __dateFromStringWithFormatAndTimezoneAndOnErrorAndOnNull + ]; + + __dateFromStringWithTimezoneOverloads = + [ + __dateFromStringWithFormatAndTimezone, + __dateFromStringWithFormatAndTimezoneAndOnErrorAndOnNull + ]; } // public properties @@ -66,5 +94,10 @@ static MqlMethod() public static MethodInfo IsMissing => __isMissing; public static MethodInfo IsNullOrMissing => __isNullOrMissing; public static MethodInfo Sigmoid => __sigmoid; + + // sets of methods + public static HashSet DateFromStringOverloads => __dateFromStringOverloads; + public static HashSet DateFromStringWithFormatOverloads => __dateFromStringWithFormatOverloads; + public static HashSet DateFromStringWithTimezoneOverloads => __dateFromStringWithTimezoneOverloads; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/QueryableMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/QueryableMethod.cs index 17896da1313..b027077ff7d 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/QueryableMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/QueryableMethod.cs @@ -28,6 +28,7 @@ internal static class QueryableMethod private static readonly MethodInfo __aggregateWithSeedAndFunc; private static readonly MethodInfo __aggregateWithSeedFuncAndResultSelector; private static readonly MethodInfo __all; + private static readonly MethodInfo __allWithPredicate; private static readonly MethodInfo __any; private static readonly MethodInfo __anyWithPredicate; private static readonly MethodInfo __append; @@ -90,7 +91,7 @@ internal static class QueryableMethod private static readonly MethodInfo __prepend; private static readonly MethodInfo __reverse; private static readonly MethodInfo __select; - private static readonly MethodInfo __selectMany; + private static readonly MethodInfo __selectManyWithSelector; private static readonly MethodInfo __selectManyWithCollectionSelectorAndResultSelector; private static readonly MethodInfo __selectManyWithCollectionSelectorTakingIndexAndResultSelector; private static readonly MethodInfo __selectManyWithSelectorTakingIndex; @@ -138,6 +139,7 @@ static QueryableMethod() __aggregateWithSeedAndFunc = ReflectionInfo.Method((IQueryable source, object seed, Expression> func) => source.Aggregate(seed, func)); __aggregateWithSeedFuncAndResultSelector = ReflectionInfo.Method((IQueryable source, object seed, Expression> func, Expression> selector) => source.Aggregate(seed, func, selector)); __all = ReflectionInfo.Method((IQueryable source, Expression> predicate) => source.All(predicate)); + __allWithPredicate = ReflectionInfo.Method((IQueryable source, Expression> predicate) => source.All(predicate)); __any = ReflectionInfo.Method((IQueryable source) => source.Any()); __anyWithPredicate = ReflectionInfo.Method((IQueryable source, Expression> predicate) => source.Any(predicate)); __append = ReflectionInfo.Method((IQueryable source, object element) => source.Append(element)); @@ -200,7 +202,7 @@ static QueryableMethod() __prepend = ReflectionInfo.Method((IQueryable source, object element) => source.Prepend(element)); __reverse = ReflectionInfo.Method((IQueryable source) => source.Reverse()); __select = ReflectionInfo.Method((IQueryable source, Expression> selector) => source.Select(selector)); - __selectMany = ReflectionInfo.Method((IQueryable source, Expression>> selector) => source.SelectMany(selector)); + __selectManyWithSelector = ReflectionInfo.Method((IQueryable source, Expression>> selector) => source.SelectMany(selector)); __selectManyWithCollectionSelectorAndResultSelector = ReflectionInfo.Method((IQueryable source, Expression>> collectionSelector, Expression> resultSelector) => source.SelectMany(collectionSelector, resultSelector)); __selectManyWithCollectionSelectorTakingIndexAndResultSelector = ReflectionInfo.Method((IQueryable source, Expression>> collectionSelector, Expression> resultSelector) => source.SelectMany(collectionSelector, resultSelector)); __selectManyWithSelectorTakingIndex = ReflectionInfo.Method((IQueryable source, Expression>> selector) => source.SelectMany(selector)); @@ -247,6 +249,7 @@ static QueryableMethod() public static MethodInfo AggregateWithSeedAndFunc => __aggregateWithSeedAndFunc; public static MethodInfo AggregateWithSeedFuncAndResultSelector => __aggregateWithSeedFuncAndResultSelector; public static MethodInfo All => __all; + public static MethodInfo AllWithPredicate => __allWithPredicate; public static MethodInfo Any => __any; public static MethodInfo AnyWithPredicate => __anyWithPredicate; public static MethodInfo Append => __append; @@ -291,7 +294,7 @@ static QueryableMethod() public static MethodInfo GroupByWithKeySelectorAndResultSelector => __groupByWithKeySelectorAndResultSelector; public static MethodInfo GroupByWithKeySelectorElementSelectorAndResultSelector => __groupByWithKeySelectorElementSelectorAndResultSelector; public static MethodInfo GroupJoin => __groupJoin; - public static MethodInfo Interset => __intersect; + public static MethodInfo Intersect => __intersect; public static MethodInfo Join => __join; public static MethodInfo Last => __last; public static MethodInfo LastOrDefault => __lastOrDefault; @@ -309,7 +312,7 @@ static QueryableMethod() public static MethodInfo Prepend => __prepend; public static MethodInfo Reverse => __reverse; public static MethodInfo Select => __select; - public static MethodInfo SelectMany => __selectMany; + public static MethodInfo SelectManyWithSelector => __selectManyWithSelector; public static MethodInfo SelectManyWithCollectionSelectorAndResultSelector => __selectManyWithCollectionSelectorAndResultSelector; public static MethodInfo SelectManyWithCollectionSelectorTakingIndexAndResultSelector => __selectManyWithCollectionSelectorTakingIndexAndResultSelector; public static MethodInfo SelectManyWithSelectorTakingIndex => __selectManyWithSelectorTakingIndex; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/StringMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/StringMethod.cs index 96177eab363..93325261f2b 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/StringMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/StringMethod.cs @@ -28,6 +28,8 @@ internal static class StringMethod private static readonly MethodInfo __anyStringInWithParams; private static readonly MethodInfo __anyStringNinWithEnumerable; private static readonly MethodInfo __anyStringNinWithParams; + private static readonly MethodInfo __compare; + private static readonly MethodInfo __compareWithIgnoreCase; private static readonly MethodInfo __concatWith1Object; private static readonly MethodInfo __concatWith2Objects; private static readonly MethodInfo __concatWith3Objects; @@ -44,6 +46,7 @@ internal static class StringMethod private static readonly MethodInfo __endsWithWithString; private static readonly MethodInfo __endsWithWithStringAndComparisonType; private static readonly MethodInfo __endsWithWithStringAndIgnoreCaseAndCulture; + private static readonly MethodInfo __equalsWithComparisonType; private static readonly MethodInfo __getChars; private static readonly MethodInfo __indexOfAny; private static readonly MethodInfo __indexOfAnyWithStartIndex; @@ -72,8 +75,7 @@ internal static class StringMethod private static readonly MethodInfo __startsWithWithString; private static readonly MethodInfo __startsWithWithStringAndComparisonType; private static readonly MethodInfo __startsWithWithStringAndIgnoreCaseAndCulture; - private static readonly MethodInfo __staticCompare; - private static readonly MethodInfo __staticCompareWithIgnoreCase; + private static readonly MethodInfo __staticEqualsWithComparisonType; private static readonly MethodInfo __stringInWithEnumerable; private static readonly MethodInfo __stringInWithParams; private static readonly MethodInfo __stringNinWithEnumerable; @@ -93,6 +95,25 @@ internal static class StringMethod private static readonly MethodInfo __trimStart; private static readonly MethodInfo __trimWithChars; + // sets of methods + private static readonly HashSet __compareOverloads; + private static readonly HashSet __concatOverloads; + private static readonly HashSet __containsOverloads; + private static readonly HashSet __endsWithOverloads; + private static readonly HashSet __indexOfOverloads; + private static readonly HashSet __indexOfBytesOverloads; + private static readonly HashSet __indexOfWithCountOverloads; + private static readonly HashSet __indexOfWithStartIndexOverloads; + private static readonly HashSet __indexOfWithStringComparisonOverloads; + private static readonly HashSet __splitOverloads; + private static readonly HashSet __startsWithOverloads; + private static readonly HashSet __toLowerOverloads; + private static readonly HashSet __toUpperOverloads; + + // arrays of sets of methods + private static readonly HashSet[] __endsWithOrStartsWithOverloads; + private static readonly HashSet[] __toLowerOrToUpperOverloads; + // static constructor static StringMethod() { @@ -114,6 +135,8 @@ static StringMethod() __anyStringInWithParams = ReflectionInfo.Method((IEnumerable s, StringOrRegularExpression[] values) => s.AnyStringIn(values)); __anyStringNinWithEnumerable = ReflectionInfo.Method((IEnumerable s, IEnumerable values) => s.AnyStringNin(values)); __anyStringNinWithParams = ReflectionInfo.Method((IEnumerable s, StringOrRegularExpression[] values) => s.AnyStringNin(values)); + __compare = ReflectionInfo.Method((string strA, string strB) => String.Compare(strA, strB)); + __compareWithIgnoreCase = ReflectionInfo.Method((string strA, string strB, bool ignoreCase) => String.Compare(strA, strB, ignoreCase)); __concatWith1Object = ReflectionInfo.Method((object arg) => string.Concat(arg)); __concatWith2Objects = ReflectionInfo.Method((object arg0, object arg1) => string.Concat(arg0, arg1)); __concatWith3Objects = ReflectionInfo.Method((object arg0, object arg1, object arg2) => string.Concat(arg0, arg1, arg2)); @@ -126,6 +149,7 @@ static StringMethod() __endsWithWithString = ReflectionInfo.Method((string s, string value) => s.EndsWith(value)); __endsWithWithStringAndComparisonType = ReflectionInfo.Method((string s, string value, StringComparison comparisonType) => s.EndsWith(value, comparisonType)); __endsWithWithStringAndIgnoreCaseAndCulture = ReflectionInfo.Method((string s, string value, bool ignoreCase, CultureInfo culture) => s.EndsWith(value, ignoreCase, culture)); + __equalsWithComparisonType = ReflectionInfo.Method((string s, string value, StringComparison comparisonType) => s.Equals(value, comparisonType)); __getChars = ReflectionInfo.Method((string s, int index) => s[index]); __indexOfAny = ReflectionInfo.Method((string s, char[] anyOf) => s.IndexOfAny(anyOf)); __indexOfAnyWithStartIndex = ReflectionInfo.Method((string s, char[] anyOf, int startIndex) => s.IndexOfAny(anyOf, startIndex)); @@ -153,8 +177,7 @@ static StringMethod() __startsWithWithString = ReflectionInfo.Method((string s, string value) => s.StartsWith(value)); __startsWithWithStringAndComparisonType = ReflectionInfo.Method((string s, string value, StringComparison comparisonType) => s.StartsWith(value, comparisonType)); __startsWithWithStringAndIgnoreCaseAndCulture = ReflectionInfo.Method((string s, string value, bool ignoreCase, CultureInfo culture) => s.StartsWith(value, ignoreCase, culture)); - __staticCompare = ReflectionInfo.Method((string strA, string strB) => String.Compare(strA, strB)); - __staticCompareWithIgnoreCase = ReflectionInfo.Method((string strA, string strB, bool ignoreCase) => String.Compare(strA, strB, ignoreCase)); + __staticEqualsWithComparisonType = ReflectionInfo.Method((string a, string b, StringComparison comparisonType) => string.Equals(a, b, comparisonType)); __stringInWithEnumerable = ReflectionInfo.Method((string s, IEnumerable values) => s.StringIn(values)); __stringInWithParams = ReflectionInfo.Method((string s, StringOrRegularExpression[] values) => s.StringIn(values)); __stringNinWithEnumerable = ReflectionInfo.Method((string s, IEnumerable values) => s.StringNin(values)); @@ -173,6 +196,139 @@ static StringMethod() __trimEnd = ReflectionInfo.Method((string s, char[] trimChars) => s.TrimEnd(trimChars)); __trimStart = ReflectionInfo.Method((string s, char[] trimChars) => s.TrimStart(trimChars)); __trimWithChars = ReflectionInfo.Method((string s, char[] trimChars) => s.Trim(trimChars)); + + // initialize sets of methods after individual methods + __compareOverloads = + [ + __compare, + __compareWithIgnoreCase + ]; + + __concatOverloads = + [ + __concatWith1Object, + __concatWith2Objects, + __concatWith2Strings, + __concatWith3Objects, + __concatWith3Strings, + __concatWith4Strings, + __concatWithObjectArray, + __concatWithStringArray + ]; + + __containsOverloads = + [ + __containsWithChar, + __containsWithCharAndComparisonType, + __containsWithString, + __containsWithStringAndComparisonType + ]; + + __endsWithOverloads = + [ + __endsWithWithChar, + __endsWithWithString, + __endsWithWithStringAndComparisonType, + __endsWithWithStringAndIgnoreCaseAndCulture, + ]; + + __indexOfOverloads = + [ + __indexOfAny, + __indexOfAnyWithStartIndex, + __indexOfAnyWithStartIndexAndCount, + __indexOfBytesWithValue, + __indexOfBytesWithValueAndStartIndex, + __indexOfBytesWithValueAndStartIndexAndCount, + __indexOfWithChar, + __indexOfWithCharAndStartIndex, + __indexOfWithCharAndStartIndexAndCount, + __indexOfWithString, + __indexOfWithStringAndComparisonType, + __indexOfWithStringAndStartIndex, + __indexOfWithStringAndStartIndexAndComparisonType, + __indexOfWithStringAndStartIndexAndCount, + __indexOfWithStringAndStartIndexAndCountAndComparisonType, + ]; + + __indexOfBytesOverloads = + [ + __indexOfBytesWithValue, + __indexOfBytesWithValueAndStartIndex, + __indexOfBytesWithValueAndStartIndexAndCount + ]; + + __indexOfWithStartIndexOverloads = + [ + __indexOfBytesWithValueAndStartIndex, + __indexOfBytesWithValueAndStartIndexAndCount, + __indexOfWithCharAndStartIndex, + __indexOfWithCharAndStartIndexAndCount, + __indexOfWithStringAndStartIndex, + __indexOfWithStringAndStartIndexAndCount, + __indexOfWithStringAndStartIndexAndComparisonType, + __indexOfWithStringAndStartIndexAndCountAndComparisonType + ]; + + __indexOfWithCountOverloads = + [ + __indexOfBytesWithValueAndStartIndexAndCount, + __indexOfWithCharAndStartIndexAndCount, + __indexOfWithStringAndStartIndexAndCount, + __indexOfWithStringAndStartIndexAndCountAndComparisonType + ]; + + __indexOfWithStringComparisonOverloads = + [ + __indexOfWithStringAndComparisonType, + __indexOfWithStringAndStartIndexAndComparisonType, + __indexOfWithStringAndStartIndexAndCountAndComparisonType + ]; + + __splitOverloads = + [ + __splitWithChars, + __splitWithCharsAndCount, + __splitWithCharsAndCountAndOptions, + __splitWithCharsAndOptions, + __splitWithStringsAndCountAndOptions, + __splitWithStringsAndOptions + ]; + + __startsWithOverloads = + [ + __startsWithWithChar, + __startsWithWithString, + __startsWithWithStringAndComparisonType, + __startsWithWithStringAndIgnoreCaseAndCulture + ]; + + __toLowerOverloads = + [ + __toLower, + __toLowerInvariant, + __toLowerWithCulture, + ]; + + __toUpperOverloads = + [ + __toUpper, + __toUpperInvariant, + __toUpperWithCulture, + ]; + + // initialize sets of methods after individual methods + __endsWithOrStartsWithOverloads = + [ + __endsWithOverloads, + __startsWithOverloads + ]; + + __toLowerOrToUpperOverloads = + [ + __toLowerOverloads, + __toUpperOverloads + ]; } // public properties @@ -180,6 +336,8 @@ static StringMethod() public static MethodInfo AnyStringInWithParams => __anyStringInWithParams; public static MethodInfo AnyStringNinWithEnumerable => __anyStringNinWithEnumerable; public static MethodInfo AnyStringNinWithParams => __anyStringNinWithParams; + public static MethodInfo Compare => __compare; + public static MethodInfo CompareWithIgnoreCase => __compareWithIgnoreCase; public static MethodInfo ConcatWith1Object => __concatWith1Object; public static MethodInfo ConcatWith2Objects => __concatWith2Objects; public static MethodInfo ConcatWith3Objects => __concatWith3Objects; @@ -196,6 +354,7 @@ static StringMethod() public static MethodInfo EndsWithWithString => __endsWithWithString; public static MethodInfo EndsWithWithStringAndComparisonType => __endsWithWithStringAndComparisonType; public static MethodInfo EndsWithWithStringAndIgnoreCaseAndCulture => __endsWithWithStringAndIgnoreCaseAndCulture; + public static MethodInfo EqualsWithComparisonType => __equalsWithComparisonType; public static MethodInfo GetChars => __getChars; public static MethodInfo IndexOfAny => __indexOfAny; public static MethodInfo IndexOfAnyWithStartIndex => __indexOfAnyWithStartIndex; @@ -224,8 +383,7 @@ static StringMethod() public static MethodInfo StartsWithWithString => __startsWithWithString; public static MethodInfo StartsWithWithStringAndComparisonType => __startsWithWithStringAndComparisonType; public static MethodInfo StartsWithWithStringAndIgnoreCaseAndCulture => __startsWithWithStringAndIgnoreCaseAndCulture; - public static MethodInfo StaticCompare => __staticCompare; - public static MethodInfo StaticCompareWithIgnoreCase => __staticCompareWithIgnoreCase; + public static MethodInfo StaticEqualsWithComparisonType => __staticEqualsWithComparisonType; public static MethodInfo StringInWithEnumerable => __stringInWithEnumerable; public static MethodInfo StringInWithParams => __stringInWithParams; public static MethodInfo StringNinWithEnumerable => __stringNinWithEnumerable; @@ -244,5 +402,24 @@ static StringMethod() public static MethodInfo TrimEnd => __trimEnd; public static MethodInfo TrimStart => __trimStart; public static MethodInfo TrimWithChars => __trimWithChars; + + // sets of methods + public static HashSet ConcatOverloads => __concatOverloads; + public static HashSet ContainsOverloads => __containsOverloads; + public static HashSet EndsWithOverloads => __endsWithOverloads; + public static HashSet IndexOfOverloads => __indexOfOverloads; + public static HashSet IndexOfBytesOverloads => __indexOfBytesOverloads; + public static HashSet IndexOfWithCountOverloads => __indexOfWithCountOverloads; + public static HashSet IndexOfWithStartIndexOverloads => __indexOfWithStartIndexOverloads; + public static HashSet IndexOfWithStringComparisonOverloads => __indexOfWithStringComparisonOverloads; + public static HashSet SplitOverloads => __splitOverloads; + public static HashSet CompareOverloads => __compareOverloads; + public static HashSet StartsWithOverloads => __startsWithOverloads; + public static HashSet ToLowerOverloads => __toLowerOverloads; + public static HashSet ToUpperOverloads => __toUpperOverloads; + + // arrays of sets of methods + public static HashSet[] EndsWithOrStartsWithOverloads => __endsWithOrStartsWithOverloads; + public static HashSet[] ToLowerOrToUpperOverloads => __toLowerOrToUpperOverloads; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/TupleOrValueTupleConstructor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/TupleOrValueTupleConstructor.cs new file mode 100644 index 00000000000..80f8502c060 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/TupleOrValueTupleConstructor.cs @@ -0,0 +1,29 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection; + +internal static class TupleOrValueTupleConstructor +{ + public static bool IsTupleOrValueTupleConstructor(ConstructorInfo constructor) + { + return + constructor != null && + constructor.DeclaringType.IsTupleOrValueTuple(); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/TupleOrValueTupleMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/TupleOrValueTupleMethod.cs new file mode 100644 index 00000000000..f62af48a601 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/TupleOrValueTupleMethod.cs @@ -0,0 +1,49 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Collections.Generic; +using System.Reflection; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection; + +internal static class TupleOrValueTupleMethod +{ + private static HashSet __createOverloads; + + static TupleOrValueTupleMethod() + { + __createOverloads = + [ + TupleMethod.Create1, + TupleMethod.Create2, + TupleMethod.Create3, + TupleMethod.Create4, + TupleMethod.Create5, + TupleMethod.Create6, + TupleMethod.Create7, + TupleMethod.Create8, + ValueTupleMethod.Create1, + ValueTupleMethod.Create2, + ValueTupleMethod.Create3, + ValueTupleMethod.Create4, + ValueTupleMethod.Create5, + ValueTupleMethod.Create6, + ValueTupleMethod.Create7, + ValueTupleMethod.Create8 + ]; + } + + public static HashSet CreateOverloads => __createOverloads; +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/WindowMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/WindowMethod.cs index 693e8762269..a48b257dbe4 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/WindowMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/WindowMethod.cs @@ -141,6 +141,8 @@ internal static class WindowMethod private static readonly MethodInfo __sumWithNullableSingle; private static readonly MethodInfo __sumWithSingle; + private static readonly HashSet __percentileOverloads; + // static constructor static WindowMethod() { @@ -262,6 +264,21 @@ static WindowMethod() __sumWithNullableInt64 = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window) => partition.Sum(selector, window)); __sumWithNullableSingle = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window) => partition.Sum(selector, window)); __sumWithSingle = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window) => partition.Sum(selector, window)); + + __percentileOverloads = + [ + __percentileWithDecimal, + __percentileWithDouble, + __percentileWithInt32, + __percentileWithInt64, + __percentileWithNullableDecimal, + __percentileWithNullableDouble, + __percentileWithNullableInt32, + __percentileWithNullableInt64, + __percentileWithNullableSingle, + __percentileWithSingle + ]; + } // public properties @@ -383,5 +400,7 @@ static WindowMethod() public static MethodInfo SumWithNullableInt64 => __sumWithNullableInt64; public static MethodInfo SumWithNullableSingle => __sumWithNullableSingle; public static MethodInfo SumWithSingle => __sumWithSingle; + + public static HashSet PercentileOverloads => __percentileOverloads; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/MissingSerializerFinder.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/MissingSerializerFinder.cs new file mode 100644 index 00000000000..fd98d43ce2a --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/MissingSerializerFinder.cs @@ -0,0 +1,62 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; +using ExpressionVisitor = System.Linq.Expressions.ExpressionVisitor; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal class MissingSerializerFinder : ExpressionVisitor +{ + public static Expression FindExpressionWithMissingSerializer(Expression expression, SerializerMap nodeSerializers) + { + var visitor = new MissingSerializerFinder(nodeSerializers); + visitor.Visit(expression); + return visitor._expressionWithMissingSerializer; + } + + private Expression _expressionWithMissingSerializer = null; + private readonly SerializerMap _nodeSerializers; + + public MissingSerializerFinder(SerializerMap nodeSerializers) + { + _nodeSerializers = nodeSerializers; + } + + public Expression ExpressionWithMissingSerializer => _expressionWithMissingSerializer; + + public override Expression Visit(Expression node) + { + if (_nodeSerializers.IsKnown(node, out var nodeSerializer)) + { + if (nodeSerializer is IIgnoreSubtreeSerializer or IUnknowableSerializer) + { + return node; // don't visit subtree + } + } + + base.Visit(node); + + if (_expressionWithMissingSerializer == null && + node != null && + _nodeSerializers.IsNotKnown(node)) + { + _expressionWithMissingSerializer = node; + } + + return node; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinder.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinder.cs new file mode 100644 index 00000000000..c12a8cb8759 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinder.cs @@ -0,0 +1,45 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal static class SerializerFinder +{ + public static void FindSerializers( + Expression expression, + ExpressionTranslationOptions translationOptions, + SerializerMap nodeSerializers) + { + var visitor = new SerializerFinderVisitor(translationOptions, nodeSerializers); + + do + { + visitor.StartPass(); + visitor.Visit(expression); + visitor.EndPass(); + } + while (visitor.IsMakingProgress); + + //#if DEBUG + var expressionWithMissingSerializer = MissingSerializerFinder.FindExpressionWithMissingSerializer(expression, nodeSerializers); + if (expressionWithMissingSerializer != null) + { + throw new ExpressionNotSupportedException(expressionWithMissingSerializer, because: "we were unable to determine which serializer to use for the result"); + } + //#endif + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderHelperMethods.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderHelperMethods.cs new file mode 100644 index 00000000000..3bfed3cee83 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderHelperMethods.cs @@ -0,0 +1,236 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; +using IOrderedEnumerableSerializer=MongoDB.Driver.Linq.Linq3Implementation.Serializers.IOrderedEnumerableSerializer; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + private void AddNodeSerializer(Expression node, IBsonSerializer serializer) => _nodeSerializers.AddSerializer(node, serializer); + + private bool AreAllKnown(IEnumerable nodes, out IReadOnlyList nodeSerializers) + { + var nodeSerializersList = new List(); + foreach (var node in nodes) + { + if (IsKnown(node, out var nodeSerializer)) + { + nodeSerializersList.Add(nodeSerializer); + } + else + { + nodeSerializers = null; + return false; + } + } + + nodeSerializers = nodeSerializersList; + return true; + } + + private bool IsAnyKnown(IEnumerable nodes, out IBsonSerializer nodeSerializer) + { + foreach (var node in nodes) + { + if (IsKnown(node, out var outSerializer)) + { + nodeSerializer = outSerializer; + return true; + } + } + + nodeSerializer = null; + return false; + } + + private bool IsAnyNotKnown(IEnumerable nodes) + { + return nodes.Any(IsNotKnown); + } + + IBsonSerializer CreateCollectionSerializerFromCollectionSerializer(Type collectionType, IBsonSerializer collectionSerializer) + { + if (collectionSerializer.ValueType == collectionType) + { + return collectionSerializer; + } + + if (collectionSerializer is IUnknowableSerializer) + { + return UnknowableSerializer.Create(collectionType); + } + + var itemSerializer = collectionSerializer.GetItemSerializer(); + return CreateCollectionSerializerFromItemSerializer(collectionType, itemSerializer); + } + + IBsonSerializer CreateCollectionSerializerFromItemSerializer(Type collectionType, IBsonSerializer itemSerializer) + { + if (itemSerializer is IUnknowableSerializer) + { + return UnknowableSerializer.Create(collectionType); + } + + return collectionType switch + { + _ when collectionType.IsArray => ArraySerializer.Create(itemSerializer), + _ when collectionType.IsConstructedGenericType && collectionType.GetGenericTypeDefinition() == typeof(IEnumerable<>) => IEnumerableSerializer.Create(itemSerializer), + _ when collectionType.IsConstructedGenericType && collectionType.GetGenericTypeDefinition() == typeof(IOrderedEnumerable<>) => IOrderedEnumerableSerializer.Create(itemSerializer), + _ when collectionType.IsConstructedGenericType && collectionType.GetGenericTypeDefinition() == typeof(IQueryable<>) => IQueryableSerializer.Create(itemSerializer), + _ => (BsonSerializer.LookupSerializer(collectionType) as IChildSerializerConfigurable)?.WithChildSerializer(itemSerializer) + }; + } + + private void DeduceBaseTypeAndDerivedTypeSerializers(Expression baseTypeExpression, Expression derivedTypeExpression) + { + IBsonSerializer baseTypeSerializer; + IBsonSerializer derivedTypeSerializer; + + if (IsNotKnown(baseTypeExpression) && IsKnown(derivedTypeExpression, out derivedTypeSerializer)) + { + baseTypeSerializer = derivedTypeSerializer.GetBaseTypeSerializer(baseTypeExpression.Type); + AddNodeSerializer(baseTypeExpression, baseTypeSerializer); + } + + if (IsNotKnown(derivedTypeExpression) && IsKnown(baseTypeExpression, out baseTypeSerializer)) + { + derivedTypeSerializer = baseTypeSerializer.GetDerivedTypeSerializer(baseTypeExpression.Type); + AddNodeSerializer(derivedTypeExpression, derivedTypeSerializer); + } + } + + private void DeduceBooleanSerializer(Expression node) + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, BooleanSerializer.Instance); + } + } + + private void DeduceCharSerializer(Expression node) + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, CharSerializer.Instance); + } + } + + private void DeduceCollectionAndCollectionSerializers(Expression collectionExpression1, Expression collectionExpression2) + { + IBsonSerializer collectionSerializer1; + IBsonSerializer collectionSerializer2; + + if (IsNotKnown(collectionExpression1) && IsKnown(collectionExpression2, out collectionSerializer2)) + { + collectionSerializer1 = CreateCollectionSerializerFromCollectionSerializer(collectionExpression1.Type, collectionSerializer2); + AddNodeSerializer(collectionExpression1, collectionSerializer1); + } + + if (IsNotKnown(collectionExpression2) && IsKnown(collectionExpression1, out collectionSerializer1)) + { + collectionSerializer2 = CreateCollectionSerializerFromCollectionSerializer(collectionExpression2.Type, collectionSerializer1); + AddNodeSerializer(collectionExpression2, collectionSerializer2); + } + } + + private void DeduceCollectionAndItemSerializers(Expression collectionExpression, Expression itemExpression) + { + DeduceItemAndCollectionSerializers(itemExpression, collectionExpression); + } + + private void DeduceItemAndCollectionSerializers(Expression itemExpression, Expression collectionExpression) + { + if (IsNotKnown(itemExpression) && IsItemSerializerKnown(collectionExpression, out var itemSerializer)) + { + AddNodeSerializer(itemExpression, itemSerializer); + } + + if (IsNotKnown(collectionExpression) && IsKnown(itemExpression, out itemSerializer)) + { + var collectionSerializer = CreateCollectionSerializerFromItemSerializer(collectionExpression.Type, itemSerializer); + if (collectionSerializer != null) + { + AddNodeSerializer(collectionExpression, collectionSerializer); + } + } + } + + private void DeduceSerializer(Expression node, IBsonSerializer serializer) + { + if (IsNotKnown(node) && serializer != null) + { + AddNodeSerializer(node, serializer); + } + } + + private void DeduceSerializers(Expression expression1, Expression expression2) + { + if (IsNotKnown(expression1) && IsKnown(expression2, out var expression2Serializer) && expression2Serializer.ValueType == expression1.Type) + { + AddNodeSerializer(expression1, expression2Serializer); + } + + if (IsNotKnown(expression2) && IsKnown(expression1, out var expression1Serializer)&& expression1Serializer.ValueType == expression2.Type) + { + AddNodeSerializer(expression2, expression1Serializer); + } + } + + private void DeduceStringSerializer(Expression node) + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, StringSerializer.Instance); + } + } + + private void DeduceUnknowableSerializer(Expression node) + { + if (IsNotKnown(node)) + { + var unknowableSerializer = UnknowableSerializer.Create(node.Type); + AddNodeSerializer(node, unknowableSerializer); + } + } + + private bool IsItemSerializerKnown(Expression node, out IBsonSerializer itemSerializer) + { + if (IsKnown(node, out var nodeSerializer) && + nodeSerializer is IBsonArraySerializer arraySerializer && + arraySerializer.TryGetItemSerializationInfo(out var itemSerializationInfo)) + { + itemSerializer = itemSerializationInfo.Serializer; + return true; + } + + itemSerializer = null; + return false; + } + + private bool IsKnown(Expression node) => _nodeSerializers.IsKnown(node); + + private bool IsKnown(Expression node, out IBsonSerializer nodeSerializer) => _nodeSerializers.IsKnown(node, out nodeSerializer); + + private bool IsNotKnown(Expression node) => _nodeSerializers.IsNotKnown(node); +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderNewExpressionSerializerCreator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderNewExpressionSerializerCreator.cs new file mode 100644 index 00000000000..15bc1ca0674 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderNewExpressionSerializerCreator.cs @@ -0,0 +1,201 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using MongoDB.Bson.Serialization; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + public IBsonSerializer CreateNewExpressionSerializer( + Expression expression, + NewExpression newExpression, + IReadOnlyList bindings) + { + var constructorInfo = newExpression.Constructor; // note: can be null when using the default constructor with a struct + var constructorArguments = newExpression.Arguments; + var classMap = CreateClassMap(newExpression.Type, constructorInfo, out var creatorMap); + + if (constructorInfo != null && creatorMap != null) + { + var constructorParameters = constructorInfo.GetParameters(); + var creatorMapParameters = creatorMap.Arguments?.ToArray(); + if (constructorParameters.Length > 0) + { + if (creatorMapParameters == null) + { + throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching properties for constructor parameters."); + } + + if (creatorMapParameters.Length != constructorParameters.Length) + { + throw new ExpressionNotSupportedException(expression, because: $"the constructor has {constructorParameters} parameters but the creatorMap has {creatorMapParameters.Length} parameters."); + } + + for (var i = 0; i < creatorMapParameters.Length; i++) + { + var creatorMapParameter = creatorMapParameters[i]; + var constructorArgumentExpression = constructorArguments[i]; + if (!IsKnown(constructorArgumentExpression, out var constructorArgumentSerializer)) + { + return null; + } + var memberMap = EnsureMemberMap(expression, classMap, creatorMapParameter); + EnsureDefaultValue(memberMap); + var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, constructorArgumentSerializer); + memberMap.SetSerializer(memberSerializer); + } + } + } + + if (bindings != null) + { + foreach (var binding in bindings) + { + var memberAssignment = (MemberAssignment)binding; + var member = memberAssignment.Member; + var memberMap = FindMemberMap(expression, classMap, member.Name); + var valueExpression = memberAssignment.Expression; + if (!IsKnown(valueExpression, out var valueSerializer)) + { + return null; + } + var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, valueSerializer); + memberMap.SetSerializer(memberSerializer); + } + } + + classMap.Freeze(); + + var serializerType = typeof(BsonClassMapSerializer<>).MakeGenericType(newExpression.Type); + return (IBsonSerializer)Activator.CreateInstance(serializerType, classMap); + } + + private static BsonClassMap CreateClassMap(Type classType, ConstructorInfo constructorInfo, out BsonCreatorMap creatorMap) + { + BsonClassMap baseClassMap = null; + if (classType.BaseType != null) + { + baseClassMap = CreateClassMap(classType.BaseType, null, out _); + } + + var classMapType = typeof(BsonClassMap<>).MakeGenericType(classType); + var classMapConstructorInfo = classMapType.GetConstructor(new Type[] { typeof(BsonClassMap) }); + var classMap = (BsonClassMap)classMapConstructorInfo.Invoke(new object[] { baseClassMap }); + if (constructorInfo != null) + { + creatorMap = classMap.MapConstructor(constructorInfo); + } + else + { + creatorMap = null; + } + + classMap.AutoMap(); + classMap.IdMemberMap?.SetElementName("_id"); // normally happens when Freeze is called but we need it sooner here + + return classMap; + } + + private static IBsonSerializer CoerceSourceSerializerToMemberSerializer(BsonMemberMap memberMap, IBsonSerializer sourceSerializer) + { + var memberType = memberMap.MemberType; + var memberSerializer = memberMap.GetSerializer(); + var sourceType = sourceSerializer.ValueType; + + if (memberType != sourceType && + memberType.ImplementsIEnumerable(out var memberItemType) && + sourceType.ImplementsIEnumerable(out var sourceItemType) && + sourceItemType == memberItemType && + sourceSerializer is IBsonArraySerializer sourceArraySerializer && + sourceArraySerializer.TryGetItemSerializationInfo(out var sourceItemSerializationInfo) && + memberSerializer is IChildSerializerConfigurable memberChildSerializerConfigurable) + { + var sourceItemSerializer = sourceItemSerializationInfo.Serializer; + return memberChildSerializerConfigurable.WithChildSerializer(sourceItemSerializer); + } + + return sourceSerializer; + } + + private static BsonMemberMap EnsureMemberMap(Expression expression, BsonClassMap classMap, MemberInfo creatorMapParameter) + { + var declaringClassMap = classMap; + while (declaringClassMap.ClassType != creatorMapParameter.DeclaringType) + { + declaringClassMap = declaringClassMap.BaseClassMap; + + if (declaringClassMap == null) + { + throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching property for constructor parameter: {creatorMapParameter.Name}"); + } + } + + foreach (var memberMap in declaringClassMap.DeclaredMemberMaps) + { + if (MemberMapMatchesCreatorMapParameter(memberMap, creatorMapParameter)) + { + return memberMap; + } + } + + return declaringClassMap.MapMember(creatorMapParameter); + + static bool MemberMapMatchesCreatorMapParameter(BsonMemberMap memberMap, MemberInfo creatorMapParameter) + { + var memberInfo = memberMap.MemberInfo; + return + memberInfo.MemberType == creatorMapParameter.MemberType && + memberInfo.Name.Equals(creatorMapParameter.Name, StringComparison.OrdinalIgnoreCase); + } + } + + private static void EnsureDefaultValue(BsonMemberMap memberMap) + { + if (memberMap.IsDefaultValueSpecified) + { + return; + } + + var defaultValue = memberMap.MemberType.IsValueType ? Activator.CreateInstance(memberMap.MemberType) : null; + memberMap.SetDefaultValue(defaultValue); + } + + private static BsonMemberMap FindMemberMap(Expression expression, BsonClassMap classMap, string memberName) + { + foreach (var memberMap in classMap.DeclaredMemberMaps) + { + if (memberMap.MemberName == memberName) + { + return memberMap; + } + } + + if (classMap.BaseClassMap != null) + { + return FindMemberMap(expression, classMap.BaseClassMap, memberName); + } + + throw new ExpressionNotSupportedException(expression, because: $"can't find member map: {memberName}"); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitBinary.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitBinary.cs new file mode 100644 index 00000000000..8deba5b6492 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitBinary.cs @@ -0,0 +1,173 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Linq.Expressions; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + protected override Expression VisitBinary(BinaryExpression node) + { + base.VisitBinary(node); + + var @operator = node.NodeType; + var leftExpression = node.Left; + var rightExpression = node.Right; + + if (node.NodeType == ExpressionType.Add && node.Type == typeof(string)) + { + DeduceStringSerializer(node); + return node; + } + + if (IsSymmetricalBinaryOperator(@operator)) + { + // expr1 op expr2 => expr1: expr2Serializer or expr2: expr1Serializer + DeduceSerializers(leftExpression, rightExpression); + } + + if (@operator == ExpressionType.ArrayIndex) + { + if (IsNotKnown(node) && + IsKnown(leftExpression, out var leftSerializer)) + { + IBsonSerializer itemSerializer; + if (leftSerializer is IPolymorphicArraySerializer polymorphicArraySerializer) + { + var index = rightExpression.GetConstantValue(node); + itemSerializer = polymorphicArraySerializer.GetItemSerializer(index); + } + else + { + itemSerializer = leftSerializer.GetItemSerializer(); + } + + // expr[index] => node: itemSerializer + AddNodeSerializer(node, itemSerializer); + } + } + + if (@operator == ExpressionType.Coalesce) + { + if (IsNotKnown(node) && + IsKnown(leftExpression, out var leftSerializer)) + { + if (leftSerializer.ValueType == node.Type) + { + AddNodeSerializer(node, leftSerializer); + } + else if ( + leftSerializer is INullableSerializer nullableSerializer && + nullableSerializer.ValueSerializer is var nullableSerializerValueSerializer && + nullableSerializerValueSerializer.ValueType == node.Type) + { + AddNodeSerializer(node, nullableSerializerValueSerializer); + } + else + { + DeduceUnknowableSerializer(node); // coalesce will be executed client-side + } + } + } + + if (leftExpression.IsConvert(out var leftConvertOperand) && + rightExpression.IsConvert(out var rightConvertOperand) && + leftConvertOperand.Type == rightConvertOperand.Type) + { + DeduceSerializers(leftConvertOperand, rightConvertOperand); + } + + if (IsNotKnown(node)) + { + var resultSerializer = GetResultSerializer(node, @operator); + if (resultSerializer != null) + { + AddNodeSerializer(node, resultSerializer); + } + } + + return node; + + static IBsonSerializer GetResultSerializer(Expression node, ExpressionType @operator) + { + switch (@operator) + { + case ExpressionType.And: + case ExpressionType.ExclusiveOr: + case ExpressionType.Or: + switch (node.Type) + { + case Type t when t == typeof(bool): return BooleanSerializer.Instance; + case Type t when t == typeof(int): return Int32Serializer.Instance; + } + goto default; + + case ExpressionType.AndAlso: + case ExpressionType.Equal: + case ExpressionType.GreaterThan: + case ExpressionType.GreaterThanOrEqual: + case ExpressionType.LessThan: + case ExpressionType.LessThanOrEqual: + case ExpressionType.NotEqual: + case ExpressionType.OrElse: + case ExpressionType.TypeEqual: + return BooleanSerializer.Instance; + + case ExpressionType.Add: + case ExpressionType.AddChecked: + case ExpressionType.Divide: + case ExpressionType.Modulo: + case ExpressionType.Multiply: + case ExpressionType.MultiplyChecked: + case ExpressionType.Subtract: + case ExpressionType.SubtractChecked: + if (StandardSerializers.TryGetSerializer(node.Type, out var resultSerializer)) + { + return resultSerializer; + } + goto default; + + default: + return null; + } + } + + static bool IsSymmetricalBinaryOperator(ExpressionType @operator) + => @operator is + ExpressionType.Add or + ExpressionType.AddChecked or + ExpressionType.And or + ExpressionType.AndAlso or + ExpressionType.Coalesce or + ExpressionType.Divide or + ExpressionType.Equal or + ExpressionType.GreaterThan or + ExpressionType.GreaterThanOrEqual or + ExpressionType.Modulo or + ExpressionType.Multiply or + ExpressionType.MultiplyChecked or + ExpressionType.Or or + ExpressionType.OrElse or + ExpressionType.Subtract or + ExpressionType.SubtractChecked; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitConditional.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitConditional.cs new file mode 100644 index 00000000000..cdfbe59e81d --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitConditional.cs @@ -0,0 +1,40 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + protected override Expression VisitConditional(ConditionalExpression node) + { + var ifTrueExpression = node.IfTrue; + var ifFalseExpression = node.IfFalse; + + DeduceConditionalSerializers(); + base.VisitConditional(node); + DeduceConditionalSerializers(); + + return node; + + void DeduceConditionalSerializers() + { + DeduceBaseTypeAndDerivedTypeSerializers(node, ifTrueExpression); + DeduceBaseTypeAndDerivedTypeSerializers(node, ifFalseExpression); + DeduceBaseTypeAndDerivedTypeSerializers(node, ifTrueExpression); // call a second time in case ifFalse is the only known serializer + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitConstant.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitConstant.cs new file mode 100644 index 00000000000..7aede943349 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitConstant.cs @@ -0,0 +1,41 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; +using MongoDB.Bson.Serialization; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + protected override Expression VisitConstant(ConstantExpression node) + { + if (IsNotKnown(node) && _useDefaultSerializerForConstants) + { + if (StandardSerializers.TryGetSerializer(node.Type, out var standardSerializer)) + { + AddNodeSerializer(node, standardSerializer); + } + else + { + var registeredSerializer = BsonSerializer.LookupSerializer(node.Type); // TODO: don't use static registry + AddNodeSerializer(node, registeredSerializer); + } + } + + return node; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitIndex.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitIndex.cs new file mode 100644 index 00000000000..9245fa024ad --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitIndex.cs @@ -0,0 +1,91 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + protected override Expression VisitIndex(IndexExpression node) + { + base.VisitIndex(node); + + var collectionExpression = node.Object; + var indexer = node.Indexer; + var arguments = node.Arguments; + + if (IsBsonValueIndexer()) + { + DeduceSerializer(node, BsonValueSerializer.Instance); + } + else if (IsDictionaryIndexer()) + { + if (IsKnown(collectionExpression, out var collectionSerializer) && + collectionSerializer is IBsonDictionarySerializer dictionarySerializer) + { + var valueSerializer = dictionarySerializer.ValueSerializer; + DeduceSerializer(node, valueSerializer); + } + } + // check array indexer AFTER dictionary indexer + else if (IsCollectionIndexer()) + { + if (IsKnown(collectionExpression, out var collectionSerializer) && + collectionSerializer is IBsonArraySerializer arraySerializer) + { + var itemSerializer = arraySerializer.GetItemSerializer(); + DeduceSerializer(node, itemSerializer); + } + } + // handle generic cases? + + return node; + + bool IsCollectionIndexer() + { + return + arguments.Count == 1 && + arguments[0] is var index && + index.Type == typeof(int); + } + + bool IsBsonValueIndexer() + { + var declaringType = indexer.DeclaringType; + return + (declaringType == typeof(BsonValue) || declaringType.IsSubclassOf(typeof(BsonValue))) && + arguments.Count == 1 && + arguments[0] is var index && + (index.Type == typeof(int) || index.Type == typeof(string)); + } + + bool IsDictionaryIndexer() + { + return + collectionExpression.Type.ImplementsDictionaryInterface(out var keyType, out var valueType) && + arguments.Count == 1 && + arguments[0] is var indexExpression && + indexExpression.Type == keyType && + indexer.PropertyType == valueType; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitLambda.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitLambda.cs new file mode 100644 index 00000000000..df044fc4060 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitLambda.cs @@ -0,0 +1,33 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + protected override Expression VisitLambda(Expression node) + { + if (IsNotKnown(node)) + { + var ignoreNodeSerializer = IgnoreNodeSerializer.Create(node.Type); + AddNodeSerializer(node, ignoreNodeSerializer); + } + + return base.VisitLambda(node); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitListInit.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitListInit.cs new file mode 100644 index 00000000000..dcac29e1792 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitListInit.cs @@ -0,0 +1,39 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + protected override Expression VisitListInit(ListInitExpression node) + { + var newExpression = node.NewExpression; + var initializers = node.Initializers; + + DeduceListInitSerializers(); + base.VisitListInit(node); + DeduceListInitSerializers(); + + return node; + + void DeduceListInitSerializers() + { + // TODO: handle initializers? + DeduceSerializers(node, newExpression); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitMember.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitMember.cs new file mode 100644 index 00000000000..ecbef27af38 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitMember.cs @@ -0,0 +1,234 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq.Expressions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Options; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + protected override Expression VisitMember(MemberExpression node) + { + IBsonSerializer containingSerializer; + var member = node.Member; + var declaringType = member.DeclaringType; + var memberName = member.Name; + + base.VisitMember(node); + + if (IsNotKnown(node)) + { + var containingExpression = node.Expression; + if (IsKnown(containingExpression, out containingSerializer)) + { + // TODO: are there are other cases that still need to be handled? + + var resultSerializer = node.Member switch + { + _ when declaringType == typeof(BsonValue) => GetBsonValuePropertySerializer(), + _ when IsCollectionCountOrLengthProperty() => GetCollectionCountOrLengthPropertySerializer(), + _ when declaringType == typeof(DateTime) => GetDateTimePropertySerializer(), + _ when declaringType.IsConstructedGenericType && declaringType.GetGenericTypeDefinition() == typeof(Dictionary<,>) => GetDictionaryPropertySerializer(), + _ when declaringType.IsConstructedGenericType && declaringType.GetGenericTypeDefinition() == typeof(IDictionary<,>) => GetIDictionaryPropertySerializer(), + _ when declaringType.IsNullable() => GetNullablePropertySerializer(), + _ when declaringType.IsTupleOrValueTuple() => GetTupleOrValueTuplePropertySerializer(), + _ => GetPropertySerializer() + }; + + AddNodeSerializer(node, resultSerializer); + } + } + + return node; + + IBsonSerializer GetBsonValuePropertySerializer() + { + return memberName switch + { + "AsBoolean" => BooleanSerializer.Instance, + "AsBsonArray" => BsonArraySerializer.Instance, + "AsBsonBinaryData" => BsonBinaryDataSerializer.Instance, + "AsBsonDateTime" => BsonDateTimeSerializer.Instance, + "AsBsonDocument" => BsonDocumentSerializer.Instance, + "AsBsonJavaScript" => BsonJavaScriptSerializer.Instance, + "AsBsonJavaScriptWithScope" => BsonJavaScriptWithScopeSerializer.Instance, + "AsBsonMaxKey" => BsonMaxKeySerializer.Instance, + "AsBsonMinKey" => BsonMinKeySerializer.Instance, + "AsBsonNull" => BsonNullSerializer.Instance, + "AsBsonRegularExpression" => BsonRegularExpressionSerializer.Instance, + "AsBsonSymbol" => BsonSymbolSerializer.Instance, + "AsBsonTimestamp" => BsonTimestampSerializer.Instance, + "AsBsonUndefined" => BsonUndefinedSerializer.Instance, + "AsBsonValue" => BsonValueSerializer.Instance, + "AsByteArray" => ByteArraySerializer.Instance, + "AsDecimal128" => Decimal128Serializer.Instance, + "AsDecimal" => DecimalSerializer.Instance, + "AsDouble" => DoubleSerializer.Instance, + "AsGuid" => GuidSerializer.StandardInstance, + "AsInt32" => Int32Serializer.Instance, + "AsInt64" => Int64Serializer.Instance, + "AsLocalTime" => DateTimeSerializer.LocalInstance, + "AsNullableBoolean" => NullableSerializer.NullableBooleanInstance, + "AsNullableDecimal128" => NullableSerializer.NullableDecimal128Instance, + "AsNullableDecimal" => NullableSerializer.NullableDecimalInstance, + "AsNullableDouble" => NullableSerializer.NullableDoubleInstance, + "AsNullableGuid" => NullableSerializer.NullableStandardGuidInstance, + "AsNullableInt32" => NullableSerializer.NullableInt32Instance, + "AsNullableInt64" => NullableSerializer.NullableInt64Instance, + "AsNullableLocalTime" => NullableSerializer.NullableLocalDateTimeInstance, + "AsNullableObjectId" => NullableSerializer.NullableObjectIdInstance, + "AsNullableUniversalTime" => NullableSerializer.NullableUtcDateTimeInstance, + "AsObjectId" => ObjectIdSerializer.Instance, + "AsRegex" => RegexSerializer.RegularExpressionInstance, + "AsString" => StringSerializer.Instance, + "AsUniversalTime" => DateTimeSerializer.UtcInstance, + // TODO: return UnknowableSerializer??? + _ => throw new ExpressionNotSupportedException(node, because: $"Unexpected member name: {memberName}") + }; + } + + IBsonSerializer GetCollectionCountOrLengthPropertySerializer() + { + return Int32Serializer.Instance; + } + + IBsonSerializer GetDateTimePropertySerializer() + { + return memberName switch + { + "Date" => DateTimeSerializer.Instance, + "Day" => Int32Serializer.Instance, + "DayOfWeek" => new EnumSerializer(BsonType.Int32), + "DayOfYear" => Int32Serializer.Instance, + "Hour" => Int32Serializer.Instance, + "Millisecond" => Int32Serializer.Instance, + "Minute" => Int32Serializer.Instance, + "Month" => Int32Serializer.Instance, + "Now" => DateTimeSerializer.Instance, + "Second" => Int32Serializer.Instance, + "Ticks" => Int64Serializer.Instance, + "TimeOfDay" => new TimeSpanSerializer(BsonType.Int64, TimeSpanUnits.Milliseconds), + "Today" => DateTimeSerializer.Instance, + "UtcNow" => DateTimeSerializer.Instance, + "Year" => Int32Serializer.Instance, + // TODO: return UnknowableSerializer??? + _ => throw new ExpressionNotSupportedException(node, because: $"Unexpected member name: {memberName}") + }; + } + + IBsonSerializer GetDictionaryPropertySerializer() + { + if (containingSerializer.Unwrapped() is not IBsonDictionarySerializer dictionarySerializer) + { + throw new ExpressionNotSupportedException(node, because: "DictionarySerializer does not implement IBsonDictionarySerializer"); + } + + var keySerializer = dictionarySerializer.KeySerializer; + var valueSerializer = dictionarySerializer.ValueSerializer; + + return memberName switch + { + "Keys" => DictionaryKeyCollectionSerializer.Create(keySerializer, valueSerializer), + "Values" => DictionaryValueCollectionSerializer.Create(keySerializer, valueSerializer), + _ => throw new ExpressionNotSupportedException(node, because: $"Unexpected member name: {memberName}") + }; + } + + IBsonSerializer GetIDictionaryPropertySerializer() + { + if (containingSerializer is not IBsonDictionarySerializer dictionarySerializer) + { + throw new ExpressionNotSupportedException(node, because: "IDictionarySerializer does not implement IBsonDictionarySerializer"); + } + + var keySerializer = dictionarySerializer.KeySerializer; + var valueSerializer = dictionarySerializer.ValueSerializer; + + return memberName switch + { + "Keys" => ICollectionSerializer.Create(keySerializer), + "Values" => ICollectionSerializer.Create(valueSerializer), + _ => throw new ExpressionNotSupportedException(node, because: $"Unexpected member name: {memberName}") + }; + } + + IBsonSerializer GetNullablePropertySerializer() + { + return memberName switch + { + "HasValue" => BooleanSerializer.Instance, + "Value" => (containingSerializer as INullableSerializer)?.ValueSerializer, + // TODO: return UnknowableSerializer??? + _ => throw new ExpressionNotSupportedException(node, because: $"Unexpected member name: {memberName}") + }; + } + + IBsonSerializer GetPropertySerializer() + { + if (containingSerializer is not IBsonDocumentSerializer documentSerializer) + { + // TODO: return UnknowableSerializer??? + throw new ExpressionNotSupportedException(node, because: $"serializer type {containingSerializer.GetType()} does not implement the {nameof(IBsonDocumentSerializer)} interface"); + } + + if (!documentSerializer.TryGetMemberSerializationInfo(memberName, out var memberSerializationInfo)) + { + // TODO: return UnknowableSerializer??? + throw new ExpressionNotSupportedException(node, because: $"serializer type {containingSerializer.GetType()} does not support a member named: {memberName}"); + } + + return memberSerializationInfo.Serializer; + } + + IBsonSerializer GetTupleOrValueTuplePropertySerializer() + { + if (containingSerializer is not IBsonTupleSerializer tupleSerializer) + { + throw new ExpressionNotSupportedException(node, because: $"serializer type {containingSerializer.GetType()} does not implement the {nameof(IBsonTupleSerializer)} interface"); + } + + return memberName switch + { + "Item1" => tupleSerializer.GetItemSerializer(1), + "Item2" => tupleSerializer.GetItemSerializer(2), + "Item3" => tupleSerializer.GetItemSerializer(3), + "Item4" => tupleSerializer.GetItemSerializer(4), + "Item5" => tupleSerializer.GetItemSerializer(5), + "Item6" => tupleSerializer.GetItemSerializer(6), + "Item7" => tupleSerializer.GetItemSerializer(7), + "Rest" => tupleSerializer.GetItemSerializer(8), + // TODO: return UnknowableSerializer??? + _ => throw new ExpressionNotSupportedException(node, because: $"Unexpected member name: {memberName}") + }; + } + + bool IsCollectionCountOrLengthProperty() + { + return + (declaringType.ImplementsInterface(typeof(IEnumerable)) || declaringType == typeof(BitArray)) && + node.Type == typeof(int) && + (member.Name == "Count" || member.Name == "Length"); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitMemberInit.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitMemberInit.cs new file mode 100644 index 00000000000..b90992d9be4 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitMemberInit.cs @@ -0,0 +1,97 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + protected override Expression VisitMemberInit(MemberInitExpression node) + { + if (IsKnown(node, out var nodeSerializer)) + { + var newExpression = node.NewExpression; + if (newExpression != null) + { + if (IsNotKnown(newExpression)) + { + AddNodeSerializer(newExpression, nodeSerializer); + } + } + + if (node.Bindings.Count > 0) + { + if (nodeSerializer is not IBsonDocumentSerializer documentSerializer) + { + throw new ExpressionNotSupportedException(node, because: $"serializer type {nodeSerializer.GetType()} does not implement IBsonDocumentSerializer interface"); + } + + foreach (var binding in node.Bindings) + { + if (binding is MemberAssignment memberAssignment) + { + if (IsNotKnown(memberAssignment.Expression)) + { + var member = memberAssignment.Member; + var memberName = member.Name; + if (!documentSerializer.TryGetMemberSerializationInfo(memberName, out var memberSerializationInfo)) + { + throw new ExpressionNotSupportedException(node, because: $"type {member.DeclaringType} does not have a member named: {memberName}"); + } + var expressionSerializer = memberSerializationInfo.Serializer; + + if (expressionSerializer.ValueType != memberAssignment.Expression.Type && + expressionSerializer.ValueType.IsAssignableFrom(memberAssignment.Expression.Type)) + { + expressionSerializer = expressionSerializer.GetDerivedTypeSerializer(memberAssignment.Expression.Type); + } + + // member = expression => expression: memberSerializer (or derivedTypeSerializer) + AddNodeSerializer(memberAssignment.Expression, expressionSerializer); + } + } + } + } + } + + base.VisitMemberInit(node); + + if (IsNotKnown(node)) + { + var resultSerializer = GetResultSerializer(); + if (resultSerializer != null) + { + AddNodeSerializer(node, resultSerializer); + } + } + + return node; + + IBsonSerializer GetResultSerializer() + { + if (node.Type == typeof(BsonDocument)) + { + return BsonDocumentSerializer.Instance; + } + var newExpression = node.NewExpression; + var bindings = node.Bindings; + return CreateNewExpressionSerializer(node, newExpression, bindings); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitMethodCall.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitMethodCall.cs new file mode 100644 index 00000000000..a24c5b2a20a --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitMethodCall.cs @@ -0,0 +1,2548 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Options; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + private static HashSet[] __averageOrMedianOrPercentileOverloads = + [ + EnumerableOrQueryableMethod.AverageOverloads, + MongoEnumerableMethod.MedianOverloads, + MongoEnumerableMethod.PercentileOverloads, + WindowMethod.PercentileOverloads + ]; + + private static HashSet[] __averageOrMedianOrPercentileWithSelectorOverloads = + [ + EnumerableOrQueryableMethod.AverageWithSelectorOverloads, + MongoEnumerableMethod.MedianWithSelectorOverloads, + MongoEnumerableMethod.PercentileWithSelectorOverloads, + WindowMethod.PercentileOverloads + ]; + + private static readonly HashSet[] __whereOverloads = + [ + EnumerableOrQueryableMethod.Where, + [MongoEnumerableMethod.WhereWithLimit] + ]; + + protected override Expression VisitMethodCall(MethodCallExpression node) + { + var method = node.Method; + var arguments = node.Arguments; + + DeduceMethodCallSerializers(); + base.VisitMethodCall(node); + DeduceMethodCallSerializers(); + + return node; + + void DeduceMethodCallSerializers() + { + switch (node.Method.Name) + { + case "Abs": DeduceAbsMethodSerializers(); break; + case "Add": DeduceAddMethodSerializers(); break; + case "AddDays": DeduceAddDaysMethodSerializers(); break; + case "AddHours": DeduceAddHoursMethodSerializers(); break; + case "AddMilliseconds": DeduceAddMillisecondsMethodSerializers(); break; + case "AddMinutes": DeduceAddMinutesMethodSerializers(); break; + case "AddMonths": DeduceAddMonthsMethodSerializers(); break; + case "AddQuarters": DeduceAddQuartersMethodSerializers(); break; + case "AddSeconds": DeduceAddSecondsMethodSerializers(); break; + case "AddTicks": DeduceAddTicksMethodSerializers(); break; + case "AddWeeks": DeduceAddWeeksMethodSerializers(); break; + case "AddYears": DeduceAddYearsMethodSerializers(); break; + case "Aggregate": DeduceAggregateMethodSerializers(); break; + case "All": DeduceAllMethodSerializers(); break; + case "Any": DeduceAnyMethodSerializers(); break; + case "AppendStage": DeduceAppendStageMethodSerializers(); break; + case "As": DeduceAsMethodSerializers(); break; + case "AsQueryable": DeduceAsQueryableMethodSerializers(); break; + case "Concat": DeduceConcatMethodSerializers(); break; + case "Constant": DeduceConstantMethodSerializers(); break; + case "Contains": DeduceContainsMethodSerializers(); break; + case "ContainsKey": DeduceContainsKeyMethodSerializers(); break; + case "ContainsValue": DeduceContainsValueMethodSerializers(); break; + case "Convert": DeduceConvertMethodSerializers(); break; + case "Create": DeduceCreateMethodSerializers(); break; + case "DefaultIfEmpty": DeduceDefaultIfEmptyMethodSerializers(); break; + case "DegreesToRadians": DeduceDegreesToRadiansMethodSerializers(); break; + case "Distinct": DeduceDistinctMethodSerializers(); break; + case "Documents": DeduceDocumentsMethodSerializers(); break; + case "Equals": DeduceEqualsMethodSerializers(); break; + case "Except": DeduceExceptMethodSerializers(); break; + case "Exists": DeduceExistsMethodSerializers(); break; + case "Exp": DeduceExpMethodSerializers(); break; + case "Field": DeduceFieldMethodSerializers(); break; + case "get_Item": DeduceGetItemMethodSerializers(); break; + case "get_Chars": DeduceGetCharsMethodSerializers(); break; + case "GroupBy": DeduceGroupByMethodSerializers(); break; + case "GroupJoin": DeduceGroupJoinMethodSerializers(); break; + case "Inject": DeduceInjectMethodSerializers(); break; + case "Intersect": DeduceIntersectMethodSerializers(); break; + case "IsMatch": DeduceIsMatchMethodSerializers(); break; + case "IsSubsetOf": DeduceIsSubsetOfMethodSerializers(); break; + case "Join": DeduceJoinMethodSerializers(); break; + case "Lookup": DeduceLookupMethodSerializers(); break; + case "OfType": DeduceOfTypeMethodSerializers(); break; + case "Parse": DeduceParseMethodSerializers(); break; + case "Pow": DeducePowMethodSerializers(); break; + case "RadiansToDegrees": DeduceRadiansToDegreesMethodSerializers(); break; + case "Range": DeduceRangeMethodSerializers(); break; + case "Repeat": DeduceRepeatMethodSerializers(); break; + case "Reverse": DeduceReverseMethodSerializers(); break; + case "Round": DeduceRoundMethodSerializers(); break; + case "Select": DeduceSelectMethodSerializers(); break; + case "SelectMany": DeduceSelectManySerializers(); break; + case "SequenceEqual": DeduceSequenceEqualMethodSerializers(); break; + case "SetEquals": DeduceSetEqualsMethodSerializers(); break; + case "SetWindowFields": DeduceSetWindowFieldsMethodSerializers(); break; + case "Shift": DeduceShiftMethodSerializers(); break; + case "Split": DeduceSplitMethodSerializers(); break; + case "Sqrt": DeduceSqrtMethodSerializers(); break; + case "StringIn": DeduceStringInMethodSerializers(); break; + case "StrLenBytes": DeduceStrLenBytesMethodSerializers(); break; + case "Subtract": DeduceSubtractMethodSerializers(); break; + case "Sum": DeduceSumMethodSerializers(); break; + case "ToArray": DeduceToArrayMethodSerializers(); break; + case "ToList": DeduceToListSerializers(); break; + case "ToString": DeduceToStringSerializers(); break; + case "Truncate": DeduceTruncateSerializers(); break; + case "Union": DeduceUnionSerializers(); break; + case "Week": DeduceWeekSerializers(); break; + case "Where": DeduceWhereSerializers(); break; + case "Zip": DeduceZipSerializers(); break; + + case "Acos": + case "Acosh": + case "Asin": + case "Asinh": + case "Atan": + case "Atanh": + case "Atan2": + case "Cos": + case "Cosh": + case "Sin": + case "Sinh": + case "Tan": + case "Tanh": + DeduceTrigonometricMethodSerializers(); + break; + + case "AllElements": + case "AllMatchingElements": + case "FirstMatchingElement": + DeduceMatchingElementsMethodSerializers(); + break; + + case "Append": + case "Prepend": + DeduceAppendOrPrependMethodSerializers(); + break; + + case "Average": + case "Median": + case "Percentile": + DeduceAverageOrMedianOrPercentileMethodSerializers(); + break; + + case "Bottom": + case "BottomN": + case "FirstN": + case "LastN": + case "MaxN": + case "MinN": + case "Top": + case "TopN": + DeducePickMethodSerializers(); + break; + + case "Ceiling": + case "Floor": + DeduceCeilingOrFloorMethodSerializers(); + break; + + case "Compare": + case "CompareTo": + DeduceCompareOrCompareToMethodSerializers(); + break; + + case "Count": + case "LongCount": + DeduceCountMethodSerializers(); + break; + + case "ElementAt": + case "ElementAtOrDefault": + DeduceElementAtMethodSerializers(); + break; + + case "EndsWith": + case "StartsWith": + DeduceEndsWithOrStartsWithMethodSerializers(); + break; + + case "First": + case "FirstOrDefault": + case "Last": + case "LastOrDefault": + case "Single": + case "SingleOrDefault": + DeduceFirstOrLastOrSingleMethodsSerializers(); + break; + + case "IndexOf": + case "IndexOfBytes": + DeduceIndexOfMethodSerializers(); + break; + + case "IsMissing": + case "IsNullOrMissing": + DeduceIsMissingOrIsNullOrMissingMethodSerializers(); + break; + + case "IsNullOrEmpty": + case "IsNullOrWhiteSpace": + DeduceIsNullOrEmptyOrIsNullOrWhiteSpaceMethodSerializers(); + break; + + case "Ln": + case "Log": + case "Log10": + DeduceLogMethodSerializers(); + break; + + case "Max": + case "Min": + DeduceMaxOrMinMethodSerializers(); + break; + + case "OrderBy": + case "OrderByDescending": + case "ThenBy": + case "ThenByDescending": + DeduceOrderByMethodSerializers(); + break; + + case "Skip": + case "SkipWhile": + case "Take": + case "TakeWhile": + DeduceSkipOrTakeMethodSerializers(); + break; + + case "StandardDeviationPopulation": + case "StandardDeviationSample": + DeduceStandardDeviationMethodSerializers(); + break; + + case "Substring": + case "SubstrBytes": + DeduceSubstringMethodSerializers(); + break; + + case "ToLower": + case "ToLowerInvariant": + case "ToUpper": + case "ToUpperInvariant": + DeduceToLowerOrToUpperSerializers(); + break; + + default: + DeduceUnknownMethodSerializer(); + break; + } + } + + void DeduceAbsMethodSerializers() + { + if (method.IsOneOf(MathMethod.AbsOverloads)) + { + var valueExpression = arguments[0]; + DeduceSerializers(node, valueExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.Add, DateTimeMethod.AddWithTimezone, DateTimeMethod.AddWithUnit, DateTimeMethod.AddWithUnitAndTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddDaysMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddDays, DateTimeMethod.AddDaysWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddHoursMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddHours, DateTimeMethod.AddHoursWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddMillisecondsMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddMilliseconds, DateTimeMethod.AddMillisecondsWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddMinutesMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddMinutes, DateTimeMethod.AddMinutesWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddMonthsMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddMonths, DateTimeMethod.AddMonthsWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddQuartersMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddQuarters, DateTimeMethod.AddQuartersWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddSecondsMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddSeconds, DateTimeMethod.AddSecondsWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddTicksMethodSerializers() + { + if (method.Is(DateTimeMethod.AddTicks)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddWeeksMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddWeeks, DateTimeMethod.AddWeeksWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddYearsMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddYears, DateTimeMethod.AddYearsWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAggregateMethodSerializers() + { + if (method.IsOneOf(EnumerableOrQueryableMethod.AggregateOverloads)) + { + var sourceExpression = arguments[0]; + _ = IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer); + + if (method.IsOneOf(EnumerableOrQueryableMethod.AggregateWithFunc)) + { + var funcLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var funcAccumulatorParameter = funcLambda.Parameters[0]; + var funcSourceItemParameter = funcLambda.Parameters[1]; + + DeduceItemAndCollectionSerializers(funcAccumulatorParameter, sourceExpression); + DeduceItemAndCollectionSerializers(funcSourceItemParameter, sourceExpression); + DeduceItemAndCollectionSerializers(funcLambda.Body, sourceExpression); + DeduceSerializers(node, funcLambda.Body); + } + + if (method.IsOneOf(EnumerableOrQueryableMethod.AggregateWithSeedAndFunc)) + { + var seedExpression = arguments[1]; + var funcLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var funcAccumulatorParameter = funcLambda.Parameters[0]; + var funcSourceItemParameter = funcLambda.Parameters[1]; + + DeduceSerializers(seedExpression, funcLambda.Body); + DeduceSerializers(funcAccumulatorParameter, funcLambda.Body); + DeduceItemAndCollectionSerializers(funcSourceItemParameter, sourceExpression); + DeduceSerializers(node, funcLambda.Body); + } + + if (method.IsOneOf(EnumerableOrQueryableMethod.AggregateWithSeedFuncAndResultSelector)) + { + var seedExpression = arguments[1]; + var funcLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var funcAccumulatorParameter = funcLambda.Parameters[0]; + var funcSourceItemParameter = funcLambda.Parameters[1]; + var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); + var resultSelectorAccumulatorParameter = resultSelectorLambda.Parameters[0]; + + DeduceSerializers(seedExpression, funcLambda.Body); + DeduceSerializers(funcAccumulatorParameter, funcLambda.Body); + DeduceItemAndCollectionSerializers(funcSourceItemParameter, sourceExpression); + DeduceSerializers(resultSelectorAccumulatorParameter, funcLambda.Body); + DeduceSerializers(node, resultSelectorLambda.Body); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAllMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.AllWithPredicate, QueryableMethod.AllWithPredicate)) + { + var sourceExpression = arguments[0]; + var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var predicateParameter = predicateLambda.Parameters.Single(); + + DeduceItemAndCollectionSerializers(predicateParameter, sourceExpression); + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAnyMethodSerializers() + { + if (method.IsOneOf(EnumerableOrQueryableMethod.AnyOverloads)) + { + if (method.IsOneOf(EnumerableOrQueryableMethod.AnyWithPredicate)) + { + var sourceExpression = arguments[0]; + var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var predicateParameter = predicateLambda.Parameters[0]; + + DeduceItemAndCollectionSerializers(predicateParameter, sourceExpression); + } + + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAppendOrPrependMethodSerializers() + { + if (method.IsOneOf(EnumerableOrQueryableMethod.AppendOrPrepend)) + { + var sourceExpression = arguments[0]; + var elementExpression = arguments[1]; + + DeduceItemAndCollectionSerializers(elementExpression, sourceExpression); + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAsMethodSerializers() + { + if (method.Is(MongoQueryableMethod.As)) + { + if (IsNotKnown(node)) + { + var resultSerializerExpression = arguments[1]; + if (resultSerializerExpression is not ConstantExpression resultSerializerConstantExpression) + { + throw new ExpressionNotSupportedException(node, because: "resultSerializer argument must be a constant"); + } + + var resultItemSerializer = (IBsonSerializer)resultSerializerConstantExpression.Value; + if (resultItemSerializer == null) + { + var resultItemType = method.GetGenericArguments()[1]; + resultItemSerializer = BsonSerializer.LookupSerializer(resultItemType); + } + + var resultSerializer = IEnumerableOrIQueryableSerializer.Create(node.Type, resultItemSerializer); + AddNodeSerializer(node, resultSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAppendStageMethodSerializers() + { + if (method.Is(MongoQueryableMethod.AppendStage)) + { + if (IsNotKnown(node)) + { + var sourceExpression = arguments[0]; + var stageExpression = arguments[1]; + var resultSerializerExpression = arguments[2]; + + if (stageExpression is not ConstantExpression stageConstantExpression) + { + throw new ExpressionNotSupportedException(node, because: "stage argument must be a constant"); + } + var stageDefinition = (IPipelineStageDefinition)stageConstantExpression.Value; + + if (resultSerializerExpression is not ConstantExpression resultSerializerConstantExpression) + { + throw new ExpressionNotSupportedException(node, because: "resultSerializer argument must be a constant"); + } + var resultItemSerializer = (IBsonSerializer)resultSerializerConstantExpression.Value; + + if (resultItemSerializer == null && IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer)) + { + var serializerRegistry = BsonSerializer.SerializerRegistry; // TODO: get correct registry + var translationOptions = new ExpressionTranslationOptions(); // TODO: get correct translation options + var renderedStage = stageDefinition.Render(sourceItemSerializer, serializerRegistry, translationOptions); + resultItemSerializer = renderedStage.OutputSerializer; + } + + if (resultItemSerializer != null) + { + var resultSerializer = IEnumerableOrIQueryableSerializer.Create(node.Type, resultItemSerializer); + AddNodeSerializer(node, resultSerializer); + } + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAsQueryableMethodSerializers() + { + if (method.Is(QueryableMethod.AsQueryable)) + { + var sourceExpression = arguments[0]; + + if (IsNotKnown(node) && IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer)) + { + var resultSerializer = NestedAsQueryableSerializer.Create(sourceItemSerializer); + AddNodeSerializer(node, resultSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAverageOrMedianOrPercentileMethodSerializers() + { + if (method.IsOneOf(__averageOrMedianOrPercentileOverloads)) + { + if (method.IsOneOf(__averageOrMedianOrPercentileWithSelectorOverloads)) + { + var sourceExpression = arguments[0]; + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var selectorSourceItemParameter = selectorLambda.Parameters[0]; + + DeduceItemAndCollectionSerializers(selectorSourceItemParameter, sourceExpression); + } + + if (IsNotKnown(node)) + { + var nodeSerializer = StandardSerializers.GetSerializer(node.Type); + AddNodeSerializer(node, nodeSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceCeilingOrFloorMethodSerializers() + { + if (method.IsOneOf(MathMethod.CeilingWithDecimal, MathMethod.CeilingWithDouble, MathMethod.FloorWithDecimal, MathMethod.FloorWithDouble)) + { + DeduceReturnsNumericSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceCompareOrCompareToMethodSerializers() + { + if (method.IsStaticCompareMethod() || + method.IsInstanceCompareToMethod() || + method.IsOneOf(StringMethod.CompareOverloads)) + { + var valueExpression = method.IsStatic ? arguments[0] : node.Object; + var comparandExpression = method.IsStatic ? arguments[1] : arguments[0]; + DeduceSerializers(valueExpression, comparandExpression); + DeduceReturnsInt32Serializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceConcatMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.Concat, QueryableMethod.Concat)) + { + var firstExpression = arguments[0]; + var secondExpression = arguments[1]; + + DeduceCollectionAndCollectionSerializers(firstExpression, secondExpression); + DeduceCollectionAndCollectionSerializers(node, firstExpression); + } + else if (method.IsOneOf(StringMethod.ConcatOverloads)) + { + DeduceReturnsStringSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceConstantMethodSerializers() + { + if (method.IsOneOf(MqlMethod.ConstantWithRepresentation, MqlMethod.ConstantWithSerializer)) + { + var valueExpression = arguments[0]; + IBsonSerializer serializer = null; + + if (IsNotKnown(node) || IsNotKnown(valueExpression)) + { + if (method.Is(MqlMethod.ConstantWithRepresentation)) + { + var representationExpression = arguments[1]; + + var representation = representationExpression.GetConstantValue(node); + var defaultSerializer = BsonSerializer.LookupSerializer(valueExpression.Type); // TODO: don't use BsonSerializer + if (defaultSerializer is IRepresentationConfigurable representationConfigurableSerializer) + { + serializer = representationConfigurableSerializer.WithRepresentation(representation); + } + } + else if (method.Is(MqlMethod.ConstantWithSerializer)) + { + var serializerExpression = arguments[1]; + serializer = serializerExpression.GetConstantValue(node); + } + } + + DeduceSerializer(valueExpression, serializer); + DeduceSerializer(node, serializer); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceContainsKeyMethodSerializers() + { + if (IsDictionaryContainsKeyExpression(out var keyExpression)) + { + var dictionaryExpression = node.Object; + if (IsNotKnown(keyExpression) && IsKnown(dictionaryExpression, out var dictionarySerializer)) + { + var keySerializer = (dictionarySerializer as IBsonDictionarySerializer)?.KeySerializer; + AddNodeSerializer(keyExpression, keySerializer); + } + + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceContainsMethodSerializers() + { + if (method.IsOneOf(StringMethod.ContainsOverloads)) + { + DeduceReturnsBooleanSerializer(); + } + else if (EnumerableMethod.IsContainsMethod(node, out var collectionExpression, out var itemExpression)) + { + DeduceCollectionAndItemSerializers(collectionExpression, itemExpression); + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceContainsValueMethodSerializers() + { + if (IsContainsValueInstanceMethod(out var collectionExpression, out var valueExpression)) + { + if (IsNotKnown(valueExpression) && + IsKnown(collectionExpression, out var collectionSerializer)) + { + if (collectionSerializer is IBsonDictionarySerializer dictionarySerializer) + { + var valueSerializer = dictionarySerializer.ValueSerializer; + AddNodeSerializer(valueExpression, valueSerializer); + } + } + + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + + bool IsContainsValueInstanceMethod(out Expression collectionExpression, out Expression valueExpression) + { + if (method.IsPublic && + method.IsStatic == false && + method.ReturnType == typeof(bool) && + method.Name == "ContainsValue" && + method.GetParameters() is var parameters && + parameters.Length == 1) + { + collectionExpression = node.Object; + valueExpression = arguments[0]; + return true; + } + + collectionExpression = null; + valueExpression = null; + return false; + } + } + + void DeduceConvertMethodSerializers() + { + if (method.Is(MqlMethod.Convert)) + { + if (IsNotKnown(node)) + { + var toType = method.GetGenericArguments()[1]; + var resultSerializer = GetResultSerializer(node, toType); + AddNodeSerializer(node, resultSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + + static IBsonSerializer GetResultSerializer(Expression expression, Type toType) + { + // TODO: should we use StandardSerializers at least for the subset of types where it would return the correct serializer? + + var isNullable = toType.IsNullable(); + var valueType = isNullable ? Nullable.GetUnderlyingType(toType) : toType; + + var valueSerializer = (IBsonSerializer)(Type.GetTypeCode(valueType) switch + { + TypeCode.Boolean => BooleanSerializer.Instance, + TypeCode.Byte => ByteSerializer.Instance, + TypeCode.Char => StringSerializer.Instance, + TypeCode.DateTime => DateTimeSerializer.Instance, + TypeCode.Decimal => DecimalSerializer.Instance, + TypeCode.Double => DoubleSerializer.Instance, + TypeCode.Int16 => Int16Serializer.Instance, + TypeCode.Int32 => Int32Serializer.Instance, + TypeCode.Int64 => Int64Serializer.Instance, + TypeCode.SByte => SByteSerializer.Instance, + TypeCode.Single => SingleSerializer.Instance, + TypeCode.String => StringSerializer.Instance, + TypeCode.UInt16 => UInt16Serializer.Instance, + TypeCode.UInt32 => Int32Serializer.Instance, + TypeCode.UInt64 => UInt64Serializer.Instance, + + _ when valueType == typeof(byte[]) => ByteArraySerializer.Instance, + _ when valueType == typeof(BsonBinaryData) => BsonBinaryDataSerializer.Instance, + _ when valueType == typeof(Decimal128) => Decimal128Serializer.Instance, + _ when valueType == typeof(Guid) => GuidSerializer.StandardInstance, + _ when valueType == typeof(ObjectId) => ObjectIdSerializer.Instance, + + _ => throw new ExpressionNotSupportedException(expression, because: $"{toType} is not a valid TTo for Convert") + }); + + return isNullable ? NullableSerializer.Create(valueSerializer) : valueSerializer; + } + } + + void DeduceCreateMethodSerializers() + { +#if NET6_0_OR_GREATER || NETSTANDARD2_1_OR_GREATER + if (method.Is(KeyValuePairMethod.Create)) + { + if (IsAnyNotKnown(arguments) && IsKnown(node, out var nodeSerializer)) + { + var keyExpression = arguments[0]; + var valueExpression = arguments[1]; + + if (nodeSerializer.IsKeyValuePairSerializer(out _, out _, out var keySerializer, out var valueSerializer)) + { + DeduceSerializer(keyExpression, keySerializer); + DeduceSerializer(valueExpression, valueSerializer); + } + } + + if (IsNotKnown(node) && AreAllKnown(arguments, out var argumentSerializers)) + { + var keySerializer = argumentSerializers[0]; + var valueSerializer = argumentSerializers[1]; + var keyValuePairSerializer = KeyValuePairSerializer.Create(BsonType.Document, keySerializer, valueSerializer); + AddNodeSerializer(node, keyValuePairSerializer); + } + } + else + #endif + if (method.IsOneOf(TupleOrValueTupleMethod.CreateOverloads)) + { + if (IsAnyNotKnown(arguments) && IsKnown(node, out var nodeSerializer)) + { + if (nodeSerializer is IBsonTupleSerializer tupleSerializer) + { + for (var i = 1; i <= arguments.Count; i++) + { + var argumentExpression = arguments[i]; + if (IsNotKnown(argumentExpression)) + { + var itemSerializer = tupleSerializer.GetItemSerializer(i); + if (i == 8) + { + itemSerializer = (itemSerializer as IBsonTupleSerializer)?.GetItemSerializer(1); + } + AddNodeSerializer(argumentExpression, itemSerializer); + } + } + } + } + + if (IsNotKnown(node) && AreAllKnown(arguments, out var argumentSerializers)) + { + var tupleType = method.ReturnType; + + if (arguments.Count == 8) + { + var item8Expression = arguments[7]; + var item8Type = item8Expression.Type; + var item8Serializer = argumentSerializers[7]; + var restTupleType = (tupleType.IsTuple() ? typeof(Tuple<>) : typeof(ValueTuple<>)).MakeGenericType(item8Type); + var restSerializer = TupleOrValueTupleSerializer.Create(restTupleType, [item8Serializer]); + argumentSerializers = argumentSerializers.Take(7).Append(restSerializer).ToArray(); + } + + var tupleSerializer = TupleOrValueTupleSerializer.Create(tupleType, argumentSerializers); + AddNodeSerializer(node, tupleSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceCountMethodSerializers() + { + if (method.IsOneOf(EnumerableOrQueryableMethod.CountOverloads)) + { + if (method.IsOneOf(EnumerableOrQueryableMethod.CountWithPredicateOverloads)) + { + var sourceExpression = arguments[0]; + var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var predicateParameter = predicateLambda.Parameters.Single(); + DeduceItemAndCollectionSerializers(predicateParameter, sourceExpression); + } + + DeduceReturnsNumericSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceDefaultIfEmptyMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.DefaultIfEmpty, EnumerableMethod.DefaultIfEmptyWithDefaultValue, QueryableMethod.DefaultIfEmpty, QueryableMethod.DefaultIfEmptyWithDefaultValue)) + { + var sourceExpression = arguments[0]; + + if (method.IsOneOf(EnumerableMethod.DefaultIfEmptyWithDefaultValue, QueryableMethod.DefaultIfEmptyWithDefaultValue)) + { + var defaultValueExpression = arguments[1]; + DeduceItemAndCollectionSerializers(defaultValueExpression, sourceExpression); + } + + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceDegreesToRadiansMethodSerializers() + { + if (method.Is(MongoDBMathMethod.DegreesToRadians)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceDistinctMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.Distinct, QueryableMethod.Distinct)) + { + var sourceExpression = arguments[0]; + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceDocumentsMethodSerializers() + { + if (method.IsOneOf(MongoQueryableMethod.Documents, MongoQueryableMethod.DocumentsWithSerializer)) + { + if (IsNotKnown(node)) + { + IBsonSerializer documentSerializer; + if (method.Is(MongoQueryableMethod.DocumentsWithSerializer)) + { + var documentSerializerExpression = arguments[2]; + documentSerializer = documentSerializerExpression.GetConstantValue(node); + } + else + { + var documentsParameter = method.GetParameters()[1]; + var documentType = documentsParameter.ParameterType.GetElementType(); + documentSerializer = BsonSerializer.LookupSerializer(documentType); // TODO: don't use static registry + } + + var nodeSerializer = IQueryableSerializer.Create(documentSerializer); + AddNodeSerializer(node, nodeSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceElementAtMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.ElementAt, EnumerableMethod.ElementAtOrDefault, QueryableMethod.ElementAt, QueryableMethod.ElementAtOrDefault, QueryableMethod.ElementAtOrDefault)) + { + var sourceExpression = arguments[0]; + DeduceItemAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceEqualsMethodSerializers() + { + if (IsEqualsReturningBooleanMethod(out var expression1, out var expression2)) + { + DeduceSerializers(expression1, expression2); + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + + bool IsEqualsReturningBooleanMethod(out Expression expression1, out Expression expression2) + { + if (method.Name == "Equals" && + method.ReturnType == typeof(bool) && + method.IsPublic) + { + if (method.IsStatic && + arguments.Count == 2) + { + expression1 = arguments[0]; + expression2 = arguments[1]; + return true; + } + + if (!method.IsStatic && + arguments.Count == 1) + { + expression1 = node.Object; + expression2 = arguments[0]; + return true; + } + + if (method.Is(StringMethod.EqualsWithComparisonType)) + { + expression1 = node.Object; + expression2 = arguments[0]; + return true; + } + + if (method.Is(StringMethod.StaticEqualsWithComparisonType)) + { + expression1 = arguments[0]; + expression2 = arguments[1]; + return true; + } + } + + expression1 = null; + expression2 = null; + return false; + } + } + + void DeduceExceptMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.Except, QueryableMethod.Except)) + { + var firstExpression = arguments[0]; + var secondExpression = arguments[1]; + DeduceCollectionAndCollectionSerializers(secondExpression, firstExpression); + DeduceCollectionAndCollectionSerializers(node, firstExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceExistsMethodSerializers() + { + if (method.Is(ArrayMethod.Exists) || ListMethod.IsExistsMethod(method)) + { + var collectionExpression = method.IsStatic ? arguments[0] : node.Object; + var predicateExpression = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, method.IsStatic ? arguments[1] : arguments[0]); + var predicateParameter = predicateExpression.Parameters.Single(); + DeduceItemAndCollectionSerializers(predicateParameter, collectionExpression); + DeduceReturnsBooleanSerializer(); + } + else if (method.Is(MqlMethod.Exists)) + { + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceExpMethodSerializers() + { + if (method.IsOneOf(MathMethod.Exp)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceFieldMethodSerializers() + { + if (method.Is(MqlMethod.Field)) + { + if (IsNotKnown(node)) + { + var fieldSerializerExpression = arguments[2]; + var fieldSerializer = fieldSerializerExpression.GetConstantValue(node); + if (fieldSerializer == null) + { + throw new ExpressionNotSupportedException(node, because: "fieldSerializer is null"); + } + + AddNodeSerializer(node, fieldSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceFirstOrLastOrSingleMethodsSerializers() + { + if (method.IsOneOf(EnumerableOrQueryableMethod.FirstOrLastOrSingleOverloads)) + { + if (method.IsOneOf(EnumerableOrQueryableMethod.FirstOrLastOrSingleWithPredicateOverloads)) + { + var sourceExpression = arguments[0]; + var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var predicateParameter = predicateLambda.Parameters.Single(); + DeduceItemAndCollectionSerializers(predicateParameter, sourceExpression); + } + + DeduceReturnsOneSourceItemSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceGetItemMethodSerializers() + { + if (IsNotKnown(node)) + { + if (BsonValueMethod.IsGetItemWithIntMethod(method) || BsonValueMethod.IsGetItemWithStringMethod(method)) + { + AddNodeSerializer(node, BsonValueSerializer.Instance); + } + else if (IsInstanceGetItemMethod(out var collectionExpression, out var indexExpression)) + { + if (IsKnown(collectionExpression, out var collectionSerializer)) + { + if (collectionSerializer is IBsonArraySerializer arraySerializer && + indexExpression.Type == typeof(int) && + arraySerializer.GetItemSerializer() is var itemSerializer && + itemSerializer.ValueType == method.ReturnType) + { + AddNodeSerializer(node, itemSerializer); + } + else if ( + collectionSerializer is IBsonDictionarySerializer dictionarySerializer && + dictionarySerializer.KeySerializer is var keySerializer && + dictionarySerializer.ValueSerializer is var valueSerializer && + keySerializer.ValueType == indexExpression.Type && + valueSerializer.ValueType == method.ReturnType) + { + AddNodeSerializer(node, valueSerializer); + } + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + bool IsInstanceGetItemMethod(out Expression collectionExpression, out Expression indexExpression) + { + if (method.IsStatic == false && + method.Name == "get_Item") + { + collectionExpression = node.Object; + indexExpression = arguments[0]; + return true; + } + + collectionExpression = null; + indexExpression = null; + return false; + } + } + + void DeduceGetCharsMethodSerializers() + { + if (method.Is(StringMethod.GetChars)) + { + DeduceCharSerializer(node); + } + + DeduceUnknowableSerializer(node); + } + + void DeduceGroupByMethodSerializers() + { + if (method.IsOneOf(EnumerableOrQueryableMethod.GroupByOverloads)) + { + var sourceExpression = arguments[0]; + var keySelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var keySelectorParameter = keySelectorLambda.Parameters.Single(); + + DeduceItemAndCollectionSerializers(keySelectorParameter, sourceExpression); + + if (method.IsOneOf(EnumerableOrQueryableMethod.GroupByWithKeySelector)) + { + if (IsNotKnown(node) && IsKnown(keySelectorLambda.Body, out var keySerializer) && IsItemSerializerKnown(sourceExpression, out var elementSerializer)) + { + var groupingSerializer = IGroupingSerializer.Create(keySerializer, elementSerializer); + var nodeSerializer = IEnumerableOrIQueryableSerializer.Create(node.Type, groupingSerializer); + AddNodeSerializer(node, nodeSerializer); + } + } + else if (method.IsOneOf(EnumerableOrQueryableMethod.GroupByWithKeySelectorAndElementSelector)) + { + var elementSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var elementSelectorParameter = elementSelectorLambda.Parameters.Single(); + DeduceItemAndCollectionSerializers(elementSelectorParameter, sourceExpression); + if (IsNotKnown(node) && IsKnown(keySelectorLambda.Body, out var keySerializer) && IsKnown(elementSelectorLambda.Body, out var elementSerializer)) + { + var groupingSerializer = IGroupingSerializer.Create(keySerializer, elementSerializer); + var nodeSerializer = IEnumerableOrIQueryableSerializer.Create(node.Type, groupingSerializer); + AddNodeSerializer(node, nodeSerializer); + } + } + else if (method.IsOneOf(EnumerableOrQueryableMethod.GroupByWithKeySelectorAndResultSelector)) + { + var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var resultSelectorKeyParameter = resultSelectorLambda.Parameters[0]; + var resultSelectorElementsParameter = resultSelectorLambda.Parameters[1]; + DeduceItemAndCollectionSerializers(keySelectorParameter, sourceExpression); + DeduceSerializers(resultSelectorKeyParameter, keySelectorLambda.Body); + DeduceCollectionAndCollectionSerializers(resultSelectorElementsParameter, sourceExpression); + DeduceResultSerializer(resultSelectorLambda.Body); + } + else if (method.IsOneOf(EnumerableOrQueryableMethod.GroupByWithKeySelectorElementSelectorAndResultSelector)) + { + var elementSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var elementSelectorParameter = elementSelectorLambda.Parameters.Single(); + var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); + var resultSelectorKeyParameter = resultSelectorLambda.Parameters[0]; + var resultSelectorElementsParameter = resultSelectorLambda.Parameters[1]; + DeduceItemAndCollectionSerializers(keySelectorParameter, sourceExpression); + DeduceItemAndCollectionSerializers(elementSelectorParameter, sourceExpression); + DeduceSerializers(resultSelectorKeyParameter, keySelectorLambda.Body); + DeduceCollectionAndItemSerializers(resultSelectorElementsParameter, elementSelectorLambda.Body); + DeduceResultSerializer(resultSelectorLambda.Body); + } + + void DeduceResultSerializer(Expression resultExpression) + { + if (IsNotKnown(node) && IsKnown(resultExpression, out var resultSerializer)) + { + var nodeSerializer = IEnumerableOrIQueryableSerializer.Create(node.Type, resultSerializer); + AddNodeSerializer(node, nodeSerializer); + } + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceGroupJoinMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.GroupJoin, QueryableMethod.GroupJoin)) + { + var outerExpression = arguments[0]; + var innerExpression = arguments[1]; + var outerKeySelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var outerKeySelectorItemParameter = outerKeySelectorLambda.Parameters.Single(); + var innerKeySelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); + var innerKeySelectorItemParameter = innerKeySelectorLambda.Parameters.Single(); + var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[4]); + var resultSelectorOuterItemParameter = resultSelectorLambda.Parameters[0]; + var resultSelectorInnerItemsParameter = resultSelectorLambda.Parameters[1]; + + DeduceItemAndCollectionSerializers(outerKeySelectorItemParameter, outerExpression); + DeduceItemAndCollectionSerializers(innerKeySelectorItemParameter, innerExpression); + DeduceItemAndCollectionSerializers(resultSelectorOuterItemParameter, outerExpression); + DeduceCollectionAndCollectionSerializers(resultSelectorInnerItemsParameter, innerExpression); + DeduceCollectionAndItemSerializers(node, resultSelectorLambda.Body); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceIndexOfMethodSerializers() + { + if (method.IsOneOf(StringMethod.IndexOfOverloads)) + { + DeduceReturnsInt32Serializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceInjectMethodSerializers() + { + if (method.Is(LinqExtensionsMethod.Inject)) + { + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceIntersectMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.Intersect, QueryableMethod.Intersect)) + { + var sourceExpression = arguments[0]; + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceIsMatchMethodSerializers() + { + if (method.Is(RegexMethod.StaticIsMatch)) + { + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceIsMissingOrIsNullOrMissingMethodSerializers() + { + if (method.IsOneOf(MqlMethod.IsMissing, MqlMethod.IsNullOrMissing)) + { + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceIsSubsetOfMethodSerializers() + { + if (IsSubsetOfMethod(method)) + { + var objectExpression = node.Object; + var otherExpression = arguments[0]; + + DeduceCollectionAndCollectionSerializers(objectExpression, otherExpression); + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + + static bool IsSubsetOfMethod(MethodInfo method) + { + var declaringType = method.DeclaringType; + var parameters = method.GetParameters(); + return + method.IsPublic && + method.IsStatic == false && + method.ReturnType == typeof(bool) && + method.Name == "IsSubsetOf" && + parameters.Length == 1 && + parameters[0] is var otherParameter && + declaringType.ImplementsIEnumerable(out var declaringTypeItemType) && + otherParameter.ParameterType.ImplementsIEnumerable(out var otherTypeItemType) && + otherTypeItemType == declaringTypeItemType; + } + } + + void DeduceJoinMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.Join, QueryableMethod.Join)) + { + var outerExpression = arguments[0]; + var innerExpression = arguments[1]; + var outerKeySelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var outerKeySelectorItemParameter = outerKeySelectorLambda.Parameters.Single(); + var innerKeySelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); + var innerKeySelectorItemParameter = innerKeySelectorLambda.Parameters.Single(); + var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[4]); + var resultSelectorOuterItemParameter = resultSelectorLambda.Parameters[0]; + var resultSelectorInnerItemsParameter = resultSelectorLambda.Parameters[1]; + + DeduceItemAndCollectionSerializers(outerKeySelectorItemParameter, outerExpression); + DeduceItemAndCollectionSerializers(innerKeySelectorItemParameter, innerExpression); + DeduceItemAndCollectionSerializers(resultSelectorOuterItemParameter, outerExpression); + DeduceItemAndCollectionSerializers(resultSelectorInnerItemsParameter, innerExpression); + DeduceCollectionAndItemSerializers(node, resultSelectorLambda.Body); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceIsNullOrEmptyOrIsNullOrWhiteSpaceMethodSerializers() + { + if (method.IsOneOf(StringMethod.IsNullOrEmpty, StringMethod.IsNullOrWhiteSpace)) + { + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceLogMethodSerializers() + { + if (method.IsOneOf(MathMethod.LogOverloads)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceLookupMethodSerializers() + { + if (method.IsOneOf(MongoQueryableMethod.LookupOverloads)) + { + var sourceExpression = arguments[0]; + + if (method.Is(MongoQueryableMethod.LookupWithDocumentsAndLocalFieldAndForeignField)) + { + var documentsLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var documentsLambdaParameter = documentsLambda.Parameters.Single(); + var localFieldLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var localFieldLambdaParameter = localFieldLambda.Parameters.Single(); + var foreignFieldLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); + var foreignFieldLambdaParameter = foreignFieldLambda.Parameters.Single(); + + DeduceItemAndCollectionSerializers(documentsLambdaParameter, sourceExpression); + DeduceItemAndCollectionSerializers(localFieldLambdaParameter, sourceExpression); + DeduceItemAndCollectionSerializers(foreignFieldLambdaParameter, documentsLambda.Body); + + if (IsNotKnown(node) && + IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer) && + IsItemSerializerKnown(documentsLambda.Body, out var documentSerializer)) + { + var lookupResultSerializer = LookupResultSerializer.Create(sourceItemSerializer, documentSerializer); + AddNodeSerializer(node, IQueryableSerializer.Create(lookupResultSerializer)); + } + } + else if (method.Is(MongoQueryableMethod.LookupWithDocumentsAndLocalFieldAndForeignFieldAndPipeline)) + { + var documentsLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var documentsLambdaParameter = documentsLambda.Parameters.Single(); + var localFieldLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var localFieldLambdaParameter = localFieldLambda.Parameters.Single(); + var foreignFieldLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); + var foreignFieldLambdaParameter = foreignFieldLambda.Parameters.Single(); + var pipelineLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[4]); + var pipelineLambdaLocalParameter = pipelineLambda.Parameters[0]; + var pipelineLambdaForeignQueryableParameter = pipelineLambda.Parameters[1]; + + DeduceItemAndCollectionSerializers(documentsLambdaParameter, sourceExpression); + DeduceItemAndCollectionSerializers(localFieldLambdaParameter, sourceExpression); + DeduceItemAndCollectionSerializers(foreignFieldLambdaParameter, documentsLambda.Body); + DeduceItemAndCollectionSerializers(pipelineLambdaLocalParameter, sourceExpression); + DeduceCollectionAndCollectionSerializers(pipelineLambdaForeignQueryableParameter, documentsLambda.Body); + + if (IsNotKnown(node) && + IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer) && + IsItemSerializerKnown(pipelineLambda.Body, out var pipelineDocumentSerializer)) + { + var lookupResultSerializer = LookupResultSerializer.Create(sourceItemSerializer, pipelineDocumentSerializer); + AddNodeSerializer(node, IQueryableSerializer.Create(lookupResultSerializer)); + } + } + else if (method.Is(MongoQueryableMethod.LookupWithDocumentsAndPipeline)) + { + var documentsLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var documentsLambdaParameter = documentsLambda.Parameters.Single(); + var pipelineLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var pipelineLambdaSourceParameter = pipelineLambda.Parameters[0]; + var pipelineLambdaQueryableDocumentParameter = pipelineLambda.Parameters[1]; + + DeduceItemAndCollectionSerializers(documentsLambdaParameter, sourceExpression); + DeduceItemAndCollectionSerializers(pipelineLambdaSourceParameter, sourceExpression); + DeduceCollectionAndCollectionSerializers(pipelineLambdaQueryableDocumentParameter, documentsLambda.Body); + + if (IsNotKnown(node) && + IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer) && + IsItemSerializerKnown(pipelineLambda.Body, out var pipelineItemSerializer)) + { + var lookupResultSerializer = LookupResultSerializer.Create(sourceItemSerializer, pipelineItemSerializer); + AddNodeSerializer(node, IQueryableSerializer.Create(lookupResultSerializer)); + } + } + + if (method.Is(MongoQueryableMethod.LookupWithFromAndLocalFieldAndForeignField)) + { + var fromExpression = arguments[1]; + var fromCollection = fromExpression.GetConstantValue(node); + var foreignDocumentSerializer = fromCollection.DocumentSerializer; + var localFieldLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var localFieldLambdaParameter = localFieldLambda.Parameters.Single(); + var foreignFieldLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); + var foreignFieldLambdaParameter = foreignFieldLambda.Parameters.Single(); + + DeduceItemAndCollectionSerializers(localFieldLambdaParameter, sourceExpression); + DeduceSerializer(foreignFieldLambdaParameter, foreignDocumentSerializer); + + if (IsNotKnown(node) && + IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer)) + { + var lookupResultSerializer = LookupResultSerializer.Create(sourceItemSerializer, foreignDocumentSerializer); + AddNodeSerializer(node, IQueryableSerializer.Create(lookupResultSerializer)); + } + } + else if (method.Is(MongoQueryableMethod.LookupWithFromAndLocalFieldAndForeignFieldAndPipeline)) + { + var fromExpression = arguments[1]; + var fromCollection = fromExpression.GetConstantValue(node); + var foreignDocumentSerializer = fromCollection.DocumentSerializer; + var localFieldLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var localFieldLambdaParameter = localFieldLambda.Parameters.Single(); + var foreignFieldLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); + var foreignFieldLambdaParameter = foreignFieldLambda.Parameters.Single(); + var pipelineLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[4]); + var pipelineLambdaLocalParameter = pipelineLambda.Parameters[0]; + var pipelineLamdbaForeignQueryableParameter = pipelineLambda.Parameters[1]; + + DeduceItemAndCollectionSerializers(localFieldLambdaParameter, sourceExpression); + DeduceSerializer(foreignFieldLambdaParameter, foreignDocumentSerializer); + DeduceItemAndCollectionSerializers(pipelineLambdaLocalParameter, sourceExpression); + + if (IsNotKnown(pipelineLamdbaForeignQueryableParameter)) + { + var foreignQueryableSerializer = IQueryableSerializer.Create(foreignDocumentSerializer); + AddNodeSerializer(pipelineLamdbaForeignQueryableParameter, foreignQueryableSerializer); + } + + if (IsNotKnown(node) && + IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer) && + IsItemSerializerKnown(pipelineLambda.Body, out var pipelineItemSerializer)) + { + var lookupResultsSerializer = LookupResultSerializer.Create(sourceItemSerializer, pipelineItemSerializer); + AddNodeSerializer(node, IQueryableSerializer.Create(lookupResultsSerializer)); + } + } + else if (method.Is(MongoQueryableMethod.LookupWithFromAndPipeline)) + { + var fromCollection = arguments[1].GetConstantValue(node); + var foreignDocumentSerializer = fromCollection.DocumentSerializer; + var pipelineLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var pipelineLambdaLocalParameter = pipelineLambda.Parameters[0]; + var pipelineLamdbaForeignQueryableParameter = pipelineLambda.Parameters[1]; + + DeduceItemAndCollectionSerializers(pipelineLambdaLocalParameter, sourceExpression); + + if (IsNotKnown(pipelineLamdbaForeignQueryableParameter)) + { + var foreignQueryableSerializer = IQueryableSerializer.Create(foreignDocumentSerializer); + AddNodeSerializer(pipelineLamdbaForeignQueryableParameter, foreignQueryableSerializer); + } + + if (IsNotKnown(node) && + IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer) && + IsItemSerializerKnown(pipelineLambda.Body, out var pipelineItemSerializer)) + { + var lookupResultSerializer = LookupResultSerializer.Create(sourceItemSerializer, pipelineItemSerializer); + AddNodeSerializer(node, IQueryableSerializer.Create(lookupResultSerializer)); + } + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceMatchingElementsMethodSerializers() + { + if (method.IsOneOf(MongoEnumerableMethod.AllElements, MongoEnumerableMethod.AllMatchingElements, MongoEnumerableMethod.FirstMatchingElement)) + { + DeduceReturnsOneSourceItemSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceMaxOrMinMethodSerializers() + { + if (method.IsOneOf(EnumerableOrQueryableMethod.MaxOrMinOverloads)) + { + if (method.IsOneOf(EnumerableOrQueryableMethod.MaxOrMinWithSelectorOverloads)) + { + var sourceExpression = arguments[0]; + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var selectorItemParameter = selectorLambda.Parameters.Single(); + + DeduceItemAndCollectionSerializers(selectorItemParameter, sourceExpression); + DeduceSerializers(node, selectorLambda.Body); + } + else + { + DeduceReturnsOneSourceItemSerializer(); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceOfTypeMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.OfType, QueryableMethod.OfType)) + { + var sourceExpression = arguments[0]; + var resultType = method.GetGenericArguments()[0]; + + if (IsNotKnown(node) && IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer)) + { + var resultItemSerializer = sourceItemSerializer.GetDerivedTypeSerializer(resultType); + var resultSerializer = IEnumerableOrIQueryableSerializer.Create(node.Type, resultItemSerializer); + AddNodeSerializer(node, resultSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceOrderByMethodSerializers() + { + if (method.IsOneOf( + EnumerableMethod.OrderBy, + EnumerableMethod.OrderByDescending, + EnumerableMethod.ThenBy, + EnumerableMethod.ThenByDescending, + QueryableMethod.OrderBy, + QueryableMethod.OrderByDescending, + QueryableMethod.ThenBy, + QueryableMethod.ThenByDescending)) + { + var sourceExpression = arguments[0]; + var keySelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var keySelectorParameter = keySelectorLambda.Parameters.Single(); + + DeduceItemAndCollectionSerializers(keySelectorParameter, sourceExpression); + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeducePickMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.PickOverloads)) + { + if (method.IsOneOf(EnumerableMethod.PickWithSortDefinitionOverloads)) + { + var sortByExpression = arguments[1]; + if (IsNotKnown(sortByExpression)) + { + var ignoreSubTreeSerializer = IgnoreSubtreeSerializer.Create(sortByExpression.Type); + AddNodeSerializer(sortByExpression, ignoreSubTreeSerializer); + } + } + + var sourceExpression = arguments[0]; + if (IsKnown(sourceExpression, out var sourceSerializer)) + { + var sourceItemSerializer = ArraySerializerHelper.GetItemSerializer(sourceSerializer); + + var selectorExpression = arguments[method.IsOneOf(EnumerableMethod.PickWithSortDefinitionOverloads) ? 2 : 1]; + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, selectorExpression); + var selectorSourceItemParameter = selectorLambda.Parameters.Single(); + if (IsNotKnown(selectorSourceItemParameter)) + { + AddNodeSerializer(selectorSourceItemParameter, sourceItemSerializer); + } + } + + if (method.IsOneOf(EnumerableMethod.PickWithComputedNOverloads)) + { + var keyExpression = arguments[method.IsOneOf(EnumerableMethod.PickWithSortDefinitionOverloads) ? 3 : 2]; + if (IsKnown(keyExpression, out var keySerializer)) + { + var nExpression = arguments[method.IsOneOf(EnumerableMethod.PickWithSortDefinitionOverloads) ? 4 : 3]; + var nLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, nExpression); + var nLambdaKeyParameter = nLambda.Parameters.Single(); + + if (IsNotKnown(nLambdaKeyParameter)) + { + AddNodeSerializer(nLambdaKeyParameter, keySerializer); + } + } + } + + if (IsNotKnown(node)) + { + var selectorExpressionIndex = method switch + { + _ when method.Is(EnumerableMethod.Bottom) => 2, + _ when method.Is(EnumerableMethod.BottomN) => 2, + _ when method.Is(EnumerableMethod.BottomNWithComputedN) => 2, + _ when method.Is(EnumerableMethod.FirstN) => 1, + _ when method.Is(EnumerableMethod.FirstNWithComputedN) => 1, + _ when method.Is(EnumerableMethod.LastN) => 1, + _ when method.Is(EnumerableMethod.LastNWithComputedN) => 1, + _ when method.Is(EnumerableMethod.MaxN) => 1, + _ when method.Is(EnumerableMethod.MaxNWithComputedN) => 1, + _ when method.Is(EnumerableMethod.MinN) => 1, + _ when method.Is(EnumerableMethod.MinNWithComputedN) => 1, + _ when method.Is(EnumerableMethod.Top) => 2, + _ when method.Is(EnumerableMethod.TopN) => 2, + _ when method.Is(EnumerableMethod.TopNWithComputedN) => 2, + _ => throw new ArgumentException($"Unrecognized method: {method.Name}.") + }; + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[selectorExpressionIndex]); + + if (IsKnown(selectorLambda.Body, out var selectorItemSerializer)) + { + var nodeSerializer = method.IsOneOf(EnumerableMethod.Bottom, EnumerableMethod.Top) ? + selectorItemSerializer : + IEnumerableSerializer.Create(selectorItemSerializer); + AddNodeSerializer(node, nodeSerializer); + } + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceParseMethodSerializers() + { + if (IsNotKnown(node)) + { + if (IsParseMethod(method)) + { + var nodeSerializer = GetParseResultSerializer(method.DeclaringType); + AddNodeSerializer(node, nodeSerializer); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + static bool IsParseMethod(MethodInfo method) + { + var parameters = method.GetParameters(); + return + method.IsPublic && + method.IsStatic && + method.ReturnType == method.DeclaringType && + parameters.Length == 1 && + parameters[0].ParameterType == typeof(string); + } + + static IBsonSerializer GetParseResultSerializer(Type declaringType) + { + return declaringType switch + { + _ when declaringType == typeof(DateTime) => DateTimeSerializer.Instance, + _ when declaringType == typeof(decimal) => DecimalSerializer.Instance, + _ when declaringType == typeof(double) => DoubleSerializer.Instance, + _ when declaringType == typeof(int) => Int32Serializer.Instance, + _ when declaringType == typeof(short) => Int64Serializer.Instance, + _ => UnknowableSerializer.Create(declaringType) + }; + } + } + + void DeducePowMethodSerializers() + { + if (method.IsOneOf(MathMethod.Pow)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceRadiansToDegreesMethodSerializers() + { + if (method.Is(MongoDBMathMethod.RadiansToDegrees)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceReturnsBooleanSerializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, BooleanSerializer.Instance); + } + } + + void DeduceReturnsDateTimeSerializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, DateTimeSerializer.UtcInstance); + } + } + + void DeduceReturnsDecimalSerializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, DecimalSerializer.Instance); + } + } + + void DeduceReturnsDoubleSerializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, DoubleSerializer.Instance); + } + } + + void DeduceReturnsInt32Serializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, Int32Serializer.Instance); + } + } + + void DeduceReturnsInt64Serializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, Int64Serializer.Instance); + } + } + + void DeduceReturnsNullableDecimalSerializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, NullableSerializer.NullableDecimalInstance); + } + } + + void DeduceReturnsNullableDoubleSerializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, NullableSerializer.NullableDoubleInstance); + } + } + + void DeduceReturnsNullableInt32Serializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, NullableSerializer.NullableInt32Instance); + } + } + + void DeduceReturnsNullableInt64Serializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, NullableSerializer.NullableInt64Instance); + } + } + + void DeduceReturnsNullableSingleSerializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, NullableSerializer.NullableSingleInstance); + } + } + + void DeduceReturnsNumericSerializer() + { + if (IsNotKnown(node) && node.Type.IsNumeric()) + { + var numericSerializer = StandardSerializers.GetSerializer(node.Type); + AddNodeSerializer(node, numericSerializer); + } + } + + void DeduceReturnsNumericOrNullableNumericSerializer() + { + if (IsNotKnown(node) && node.Type.IsNumericOrNullableNumeric()) + { + var numericSerializer = StandardSerializers.GetSerializer(node.Type); + AddNodeSerializer(node, numericSerializer); + } + } + + void DeduceReturnsOneSourceItemSerializer() + { + var sourceExpression = arguments[0]; + + if (IsNotKnown(node) && IsKnown(sourceExpression, out var sourceSerializer)) + { + var nodeSerializer = sourceSerializer is IUnknowableSerializer ? + UnknowableSerializer.Create(node.Type) : + ArraySerializerHelper.GetItemSerializer(sourceSerializer); + AddNodeSerializer(node, nodeSerializer); + } + } + + void DeduceReturnsSingleSerializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, SingleSerializer.Instance); + } + } + + void DeduceReturnsStringSerializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, StringSerializer.Instance); + } + } + + void DeduceReturnsTimeSpanSerializer(TimeSpanUnits units) + { + if (IsNotKnown(node)) + { + var resultSerializer = new TimeSpanSerializer(BsonType.Int64, units); + AddNodeSerializer(node, resultSerializer); + } + } + + void DeduceRangeMethodSerializers() + { + if (method.Is(EnumerableMethod.Range)) + { + var elementExpression = arguments[0]; + DeduceCollectionAndItemSerializers(node, elementExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceRepeatMethodSerializers() + { + if (method.Is(EnumerableMethod.Repeat)) + { + var elementExpression = arguments[0]; + DeduceCollectionAndItemSerializers(node, elementExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceReverseMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.Reverse, QueryableMethod.Reverse)) + { + var sourceExpression = arguments[0]; + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceRoundMethodSerializers() + { + if (method.IsOneOf(MathMethod.RoundWithDecimal, MathMethod.RoundWithDecimalAndDecimals, MathMethod.RoundWithDouble, MathMethod.RoundWithDoubleAndDigits)) + { + if (IsNotKnown(node)) + { + var resultSerializer = StandardSerializers.GetSerializer(node.Type); + AddNodeSerializer(node, resultSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSelectMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.Select, QueryableMethod.Select)) + { + var sourceExpression = arguments[0]; + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var selectorParameter = selectorLambda.Parameters.Single(); + DeduceItemAndCollectionSerializers(selectorParameter, sourceExpression); + DeduceCollectionAndItemSerializers(node, selectorLambda.Body); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSelectManySerializers() + { + if (method.IsOneOf(EnumerableOrQueryableMethod.SelectManyOverloads)) + { + var sourceExpression = arguments[0]; + + if (method.IsOneOf(EnumerableOrQueryableMethod.SelectManyWithSelector)) + { + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var selectorSourceParameter = selectorLambda.Parameters.Single(); + + DeduceItemAndCollectionSerializers(selectorSourceParameter, sourceExpression); + DeduceCollectionAndCollectionSerializers(node, selectorLambda.Body); + } + + if (method.IsOneOf(EnumerableOrQueryableMethod.SelectManyWithCollectionSelectorAndResultSelector)) + { + var collectionSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + + var collectionSelectorSourceItemParameter = collectionSelectorLambda.Parameters.Single(); + var resultSelectorSourceItemParameter = resultSelectorLambda.Parameters[0]; + var resultSelectorCollectionItemParameter = resultSelectorLambda.Parameters[1]; + + DeduceItemAndCollectionSerializers(collectionSelectorSourceItemParameter, sourceExpression); + DeduceItemAndCollectionSerializers(resultSelectorSourceItemParameter, sourceExpression); + DeduceItemAndCollectionSerializers(resultSelectorCollectionItemParameter, collectionSelectorLambda.Body); + DeduceCollectionAndItemSerializers(node, resultSelectorLambda.Body); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSequenceEqualMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.SequenceEqual, QueryableMethod.SequenceEqual)) + { + var source1Expression = arguments[0]; + var source2Expression = arguments[1]; + + DeduceCollectionAndCollectionSerializers(source1Expression, source2Expression); + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSetEqualsMethodSerializers() + { + if (ISetMethod.IsSetEqualsMethod(method)) + { + var objectExpression = node.Object; + var otherExpression = arguments[0]; + + DeduceCollectionAndCollectionSerializers(objectExpression, otherExpression); + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSetWindowFieldsMethodSerializers() + { + if (method.Is(EnumerableMethod.First)) + { + var objectExpression = node.Object; + var otherExpression = arguments[0]; + + DeduceCollectionAndCollectionSerializers(objectExpression, otherExpression); + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceShiftMethodSerializers() + { + if (method.IsOneOf(WindowMethod.Shift, WindowMethod.ShiftWithDefaultValue)) + { + var sourceExpression = arguments[0]; + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var selectorSourceItemParameter = selectorLambda.Parameters[0]; + + DeduceItemAndCollectionSerializers(selectorSourceItemParameter, sourceExpression); + + if (IsNotKnown(node) && IsKnown(selectorLambda.Body, out var resultSerializer)) + { + AddNodeSerializer(node, resultSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSplitMethodSerializers() + { + if (method.IsOneOf(StringMethod.SplitOverloads)) + { + if (IsNotKnown(node)) + { + var nodeSerializer = ArraySerializer.Create(StringSerializer.Instance); + AddNodeSerializer(node, nodeSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSqrtMethodSerializers() + { + if (method.Is(MathMethod.Sqrt)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceStandardDeviationMethodSerializers() + { + if (method.IsOneOf(MongoEnumerableMethod.StandardDeviationOverloads)) + { + if (method.IsOneOf(MongoEnumerableMethod.StandardDeviationWithSelectorOverloads)) + { + var sourceExpression = arguments[0]; + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var selectorItemParameter = selectorLambda.Parameters.Single(); + DeduceItemAndCollectionSerializers(selectorItemParameter, sourceExpression); + } + + DeduceReturnsNumericOrNullableNumericSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceEndsWithOrStartsWithMethodSerializers() + { + if (method.IsOneOf(StringMethod.EndsWithOrStartsWithOverloads)) + { + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceStringInMethodSerializers() + { + if (method.IsOneOf(StringMethod.StringInWithEnumerable, StringMethod.StringInWithParams)) + { + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceStrLenBytesMethodSerializers() + { + if (method.Is(StringMethod.StrLenBytes)) + { + DeduceReturnsInt32Serializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSubstringMethodSerializers() + { + if (method.IsOneOf(StringMethod.Substring, StringMethod.SubstringWithLength, StringMethod.SubstrBytes)) + { + DeduceReturnsStringSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSubtractMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.SubtractReturningDateTimeOverloads)) + { + DeduceReturnsDateTimeSerializer(); + } + else if (method.IsOneOf(DateTimeMethod.SubtractReturningInt64Overloads)) + { + DeduceReturnsInt64Serializer(); + } + else if (method.IsOneOf(DateTimeMethod.SubtractReturningTimeSpanWithMillisecondsUnitsOverloads)) + { + var units = TimeSpanUnits.Milliseconds; + DeduceReturnsTimeSpanSerializer(units); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSumMethodSerializers() + { + if (method.IsOneOf(EnumerableOrQueryableMethod.SumOverloads)) + { + if (method.IsOneOf(EnumerableOrQueryableMethod.SumWithSelectorOverloads)) + { + var sourceExpression = arguments[0]; + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var selectorParameter = selectorLambda.Parameters.Single(); + DeduceItemAndCollectionSerializers(selectorParameter, sourceExpression); + } + + var returnType = node.Type; + switch (returnType) + { + case not null when returnType == typeof(decimal): DeduceReturnsDecimalSerializer(); break; + case not null when returnType == typeof(double): DeduceReturnsDoubleSerializer(); break; + case not null when returnType == typeof(int): DeduceReturnsInt32Serializer(); break; + case not null when returnType == typeof(long): DeduceReturnsInt64Serializer(); break; + case not null when returnType == typeof(float): DeduceReturnsSingleSerializer(); break; + case not null when returnType == typeof(decimal?): DeduceReturnsNullableDecimalSerializer(); break; + case not null when returnType == typeof(double?): DeduceReturnsNullableDoubleSerializer(); break; + case not null when returnType == typeof(int?): DeduceReturnsNullableInt32Serializer(); break; + case not null when returnType == typeof(long?): DeduceReturnsNullableInt64Serializer(); break; + case not null when returnType == typeof(float?): DeduceReturnsNullableSingleSerializer(); break; + + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSkipOrTakeMethodSerializers() + { + if (method.IsOneOf(EnumerableOrQueryableMethod.SkipOrTakeOverloads)) + { + var sourceExpression = arguments[0]; + + if (method.IsOneOf(EnumerableOrQueryableMethod.SkipOrTakeWhile)) + { + var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var predicateParameter = predicateLambda.Parameters.Single(); + DeduceItemAndCollectionSerializers(predicateParameter, sourceExpression); + } + + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceToArrayMethodSerializers() + { + if (IsToArrayMethod(out var sourceExpression)) + { + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + + bool IsToArrayMethod(out Expression sourceExpression) + { + if (method.IsPublic && + method.Name == "ToArray" && + method.GetParameters().Length == (method.IsStatic ? 1 : 0)) + { + sourceExpression = method.IsStatic ? arguments[0] : node.Object; + return true; + } + + sourceExpression = null; + return false; + } + } + + void DeduceToListSerializers() + { + if (IsNotKnown(node)) + { + var source = method.IsStatic ? arguments[0] : node.Object; + if (IsKnown(source, out var sourceSerializer)) + { + var sourceItemSerializer = ArraySerializerHelper.GetItemSerializer(sourceSerializer); + var resultSerializer = ListSerializer.Create(sourceItemSerializer); + AddNodeSerializer(node, resultSerializer); + } + } + } + + void DeduceToLowerOrToUpperSerializers() + { + if (method.IsOneOf(StringMethod.ToLowerOrToUpperOverloads)) + { + DeduceReturnsStringSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceToStringSerializers() + { + DeduceReturnsStringSerializer(); + } + + void DeduceTrigonometricMethodSerializers() + { + if (method.IsOneOf(MathMethod.TrigonometricMethods)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceTruncateSerializers() + { + if (method.IsOneOf(DateTimeMethod.Truncate, DateTimeMethod.TruncateWithBinSize, DateTimeMethod.TruncateWithBinSizeAndTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else if (method.IsOneOf(MathMethod.TruncateDecimal, MathMethod.TruncateDouble)) + { + DeduceReturnsNumericSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceUnionSerializers() + { + if (method.IsOneOf(EnumerableMethod.Union, QueryableMethod.Union)) + { + var sourceExpression = arguments[0]; + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceUnknownMethodSerializer() + { + DeduceUnknowableSerializer(node); + } + + void DeduceWeekSerializers() + { + if (method.IsOneOf(DateTimeMethod.Week, DateTimeMethod.WeekWithTimezone)) + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, Int32Serializer.Instance); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceWhereSerializers() + { + if (method.IsOneOf(__whereOverloads)) + { + var sourceExpression = arguments[0]; + var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var predicateParameter = predicateLambda.Parameters.Single(); + DeduceItemAndCollectionSerializers(predicateParameter, sourceExpression); + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceZipSerializers() + { + if (method.IsOneOf(EnumerableMethod.Zip, QueryableMethod.Zip)) + { + var firstExpression = arguments[0]; + var secondExpression = arguments[1]; + var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var resultSelectorFirstParameter = resultSelectorLambda.Parameters[0]; + var resultSelectorSecondParameter = resultSelectorLambda.Parameters[1]; + + if (IsNotKnown(resultSelectorFirstParameter) && IsKnown(firstExpression, out var firstSerializer)) + { + var firstItemSerializer = ArraySerializerHelper.GetItemSerializer(firstSerializer); + AddNodeSerializer(resultSelectorFirstParameter, firstItemSerializer); + } + + if (IsNotKnown(resultSelectorSecondParameter) && IsKnown(secondExpression, out var secondSerializer)) + { + var secondItemSerializer = ArraySerializerHelper.GetItemSerializer(secondSerializer); + AddNodeSerializer(resultSelectorSecondParameter, secondItemSerializer); + } + + if (IsNotKnown(node) && IsKnown(resultSelectorLambda.Body, out var resultItemSerializer)) + { + var resultSerializer = IEnumerableOrIQueryableSerializer.Create(node.Type, resultItemSerializer); + AddNodeSerializer(node, resultSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + bool IsDictionaryContainsKeyExpression(out Expression keyExpression) + { + if (DictionaryMethod.IsContainsKeyMethod(method)) + { + keyExpression = arguments[0]; + return true; + } + + keyExpression = null; + return false; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitNew.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitNew.cs new file mode 100644 index 00000000000..913ef139d13 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitNew.cs @@ -0,0 +1,140 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Options; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + protected override Expression VisitNew(NewExpression node) + { + var constructor = node.Constructor; + var arguments = node.Arguments; + IBsonSerializer nodeSerializer; + + if (IsKnown(node, out nodeSerializer) && + arguments.Any(IsNotKnown)) + { + if (!typeof(BsonValue).IsAssignableFrom(node.Type) && + nodeSerializer is IBsonDocumentSerializer) + { + var matchingMemberSerializationInfos = nodeSerializer.GetMatchingMemberSerializationInfosForConstructorParameters(node, node.Constructor); + for (var i = 0; i < matchingMemberSerializationInfos.Count; i++) + { + var argument = arguments[i]; + var matchingMemberSerializationInfo = matchingMemberSerializationInfos[i]; + + if (IsNotKnown(argument)) + { + // arg => arg: matchingMemberSerializer + AddNodeSerializer(argument, matchingMemberSerializationInfo.Serializer); + } + } + } + } + + base.VisitNew(node); + + if (IsNotKnown(node)) + { + nodeSerializer = CreateSerializer(constructor); + if (nodeSerializer != null) + { + AddNodeSerializer(node, nodeSerializer); + } + } + + return node; + + IBsonSerializer CreateSerializer(ConstructorInfo constructor) + { + if (constructor == null) + { + return CreateNewExpressionSerializer(node, node, bindings: null); + } + else if (constructor.DeclaringType == typeof(BsonDocument)) + { + return BsonDocumentSerializer.Instance; + } + else if (constructor.DeclaringType == typeof(BsonValue)) + { + return BsonValueSerializer.Instance; + } + else if (constructor.DeclaringType == typeof(DateTime)) + { + return DateTimeSerializer.Instance; + } + else if (DictionaryConstructor.IsWithIEnumerableKeyValuePairConstructor(constructor)) + { + var collectionExpression = arguments[0]; + if (IsItemSerializerKnown(collectionExpression, out var itemSerializer) && + itemSerializer.IsKeyValuePairSerializer(out _, out _, out var keySerializer, out var valueSerializer)) + { + return DictionarySerializer.Create(DictionaryRepresentation.Document, keySerializer, valueSerializer); + } + } + else if (HashSetConstructor.IsWithCollectionConstructor(constructor)) + { + var collectionExpression = arguments[0]; + if (IsItemSerializerKnown(collectionExpression, out var itemSerializer)) + { + return HashSetSerializer.Create(itemSerializer); + } + } + else if (ListConstructor.IsWithCollectionConstructor(constructor)) + { + var collectionExpression = arguments[0]; + if (IsItemSerializerKnown(collectionExpression, out var itemSerializer)) + { + return ListSerializer.Create(itemSerializer); + } + } + else if (KeyValuePairConstructor.IsWithKeyAndValueConstructor(constructor)) + { + var key = arguments[0]; + var value = arguments[1]; + if (IsKnown(key, out var keySerializer) && + IsKnown(value, out var valueSerializer)) + { + return KeyValuePairSerializer.Create(BsonType.Document, keySerializer, valueSerializer); + } + } + else if (TupleOrValueTupleConstructor.IsTupleOrValueTupleConstructor(constructor)) + { + if (AreAllKnown(arguments, out var argumentSerializers)) + { + return TupleOrValueTupleSerializer.Create(constructor.DeclaringType, argumentSerializers); + } + } + else + { + return CreateNewExpressionSerializer(node, node, bindings: null); + } + + return null; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitNewArray.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitNewArray.cs new file mode 100644 index 00000000000..5c9d30d8946 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitNewArray.cs @@ -0,0 +1,146 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + protected override Expression VisitNewArray(NewArrayExpression node) + { + DeduceNewArraySerializers(); + base.VisitNewArray(node); + DeduceNewArraySerializers(); + + return node; + + void DeduceNewArraySerializers() + { + switch (node.NodeType) + { + case ExpressionType.NewArrayBounds: + DeduceNewArrayBoundsSerializers(); + break; + + case ExpressionType.NewArrayInit: + DeduceNewArrayInitSerializers(); + break; + } + } + + void DeduceNewArrayBoundsSerializers() + { + throw new NotImplementedException(); + } + + void DeduceNewArrayInitSerializers() + { + var itemExpressions = node.Expressions; + IBsonSerializer itemSerializer; + + if (IsAnyNotKnown(itemExpressions) && IsKnown(node, out var arraySerializer)) + { + if (arraySerializer is IPolymorphicArraySerializer polymorphicArraySerializer) + { + for (var i = 0; i < itemExpressions.Count; i++) + { + var itemExpression = itemExpressions[i]; + if (IsNotKnown(itemExpression)) + { + itemSerializer = polymorphicArraySerializer.GetItemSerializer(i); + AddNodeSerializer(itemExpression, itemSerializer); + } + } + } + else + { + itemSerializer = arraySerializer.GetItemSerializer(); + foreach (var itemExpression in itemExpressions) + { + if (IsNotKnown(itemExpression)) + { + AddNodeSerializer(itemExpression, itemSerializer); + } + } + } + } + + if (IsAnyNotKnown(itemExpressions) && IsAnyKnown(itemExpressions, out itemSerializer)) + { + var firstItemType = itemExpressions.First().Type; + if (itemExpressions.All(e => e.Type == firstItemType)) + { + foreach (var itemExpression in itemExpressions) + { + if (IsNotKnown(itemExpression)) + { + AddNodeSerializer(itemExpression, itemSerializer); + } + } + } + } + + if (IsNotKnown(node)) + { + if (AreAllKnown(itemExpressions, out var itemSerializers)) + { + if (AllItemSerializersAreEqual(itemSerializers, out itemSerializer)) + { + arraySerializer = ArraySerializer.Create(itemSerializer); + } + else + { + var itemType = node.Type.GetElementType(); + arraySerializer = PolymorphicArraySerializer.Create(itemType, itemSerializers); + } + AddNodeSerializer(node, arraySerializer); + } + } + + static bool AllItemSerializersAreEqual(IReadOnlyList itemSerializers, out IBsonSerializer itemSerializer) + { + switch (itemSerializers.Count) + { + case 0: + itemSerializer = null; + return false; + case 1: + itemSerializer = itemSerializers[0]; + return true; + default: + var firstItemSerializer = itemSerializers[0]; + if (itemSerializers.Skip(1).All(s => s.Equals(firstItemSerializer))) + { + itemSerializer = firstItemSerializer; + return true; + } + else + { + itemSerializer = null; + return false; + } + } + } + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitTypeBinary.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitTypeBinary.cs new file mode 100644 index 00000000000..40ec74177ab --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitTypeBinary.cs @@ -0,0 +1,30 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + protected override Expression VisitTypeBinary(TypeBinaryExpression node) + { + base.VisitTypeBinary(node); + + DeduceBooleanSerializer(node); + + return node; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitUnary.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitUnary.cs new file mode 100644 index 00000000000..96418e73305 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitUnary.cs @@ -0,0 +1,306 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Linq.Expressions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + protected override Expression VisitUnary(UnaryExpression node) + { + var unaryOperator = node.NodeType; + var operand = node.Operand; + + base.VisitUnary(node); + + switch (unaryOperator) + { + case ExpressionType.Negate: + DeduceNegateSerializers(); // TODO: fold into general case? + break; + + default: + DeduceUnaryOperatorSerializers(); + break; + } + + return node; + + void DeduceNegateSerializers() + { + DeduceSerializers(node, operand); + } + + void DeduceUnaryOperatorSerializers() + { + if (IsNotKnown(node)) + { + var resultSerializer = unaryOperator switch + { + ExpressionType.ArrayLength => Int32Serializer.Instance, + ExpressionType.Convert or ExpressionType.TypeAs => GetConvertSerializer(), + ExpressionType.Not => StandardSerializers.GetSerializer(node.Type), + ExpressionType.Quote => IgnoreNodeSerializer.Create(node.Type), + _ => null + }; + + if (resultSerializer != null) + { + AddNodeSerializer(node, resultSerializer); + } + } + } + + IBsonSerializer GetConvertSerializer() + { + var sourceType = operand.Type; + var targetType = node.Type; + + // handle double conversion (BsonValue)(object)x + if (targetType == typeof(BsonValue) && + operand is UnaryExpression unarySourceExpression && + unarySourceExpression.NodeType == ExpressionType.Convert && + unarySourceExpression.Type == typeof(object)) + { + operand = unarySourceExpression.Operand; + } + + if (IsKnown(operand, out var sourceSerializer)) + { + return GetTargetSerializer(node, sourceType, targetType, sourceSerializer); + } + + return null; + + static IBsonSerializer GetTargetSerializer(UnaryExpression node, Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + if (targetType == sourceType) + { + return sourceSerializer; + } + + // handle conversion to BsonValue before any others + if (targetType == typeof(BsonValue)) + { + return GetConvertToBsonValueSerializer(node, sourceSerializer); + } + + // from Nullable must be handled before to Nullable + if (IsConvertFromNullableType(sourceType)) + { + return GetConvertFromNullableTypeSerializer(node, sourceType, targetType, sourceSerializer); + } + + if (IsConvertToNullableType(targetType, out var valueType)) + { + var valueSerializer = valueType == targetType ? sourceSerializer : GetTargetSerializer(node, sourceType, valueType, sourceSerializer); + return valueSerializer != null ? GetConvertToNullableTypeSerializer(node, sourceType, targetType, valueSerializer) : null; + } + + // from here on we know there are no longer any Nullable types involved + + if (sourceType == typeof(BsonValue)) + { + return GetConvertFromBsonValueSerializer(node, targetType); + } + + if (IsConvertEnumToUnderlyingType(sourceType, targetType)) + { + return GetConvertEnumToUnderlyingTypeSerializer(node, sourceType, targetType, sourceSerializer); + } + + if (IsConvertUnderlyingTypeToEnum(sourceType, targetType)) + { + return GetConvertUnderlyingTypeToEnumSerializer(node, sourceType, targetType, sourceSerializer); + } + + if (IsConvertEnumToEnum(sourceType, targetType)) + { + return GetConvertEnumToEnumSerializer(node, sourceType, targetType, sourceSerializer); + } + + if (IsConvertToBaseType(sourceType, targetType)) + { + return GetConvertToBaseTypeSerializer(node, sourceType, targetType, sourceSerializer); + } + + if (IsConvertToDerivedType(sourceType, targetType)) + { + return GetConvertToDerivedTypeSerializer(node, targetType, sourceSerializer); + } + + if (IsNumericConversion(sourceType, targetType)) + { + return GetNumericConversionSerializer(node, sourceType, targetType, sourceSerializer); + } + + return null; + } + + static IBsonSerializer GetConvertFromBsonValueSerializer(UnaryExpression expression, Type targetType) + { + return targetType switch + { + _ when targetType == typeof(string) => StringSerializer.Instance, + _ => throw new ExpressionNotSupportedException(expression, because: $"conversion from BsonValue to {targetType} is not supported") + }; + } + + static IBsonSerializer GetConvertToBaseTypeSerializer(UnaryExpression expression, Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + var derivedTypeSerializer = sourceSerializer; + return DowncastingSerializer.Create(targetType, sourceType, derivedTypeSerializer); + } + + static IBsonSerializer GetConvertToDerivedTypeSerializer(UnaryExpression expression, Type targetType, IBsonSerializer sourceSerializer) + { + var derivedTypeSerializer = sourceSerializer.GetDerivedTypeSerializer(targetType); + return derivedTypeSerializer; + } + + static IBsonSerializer GetConvertToBsonValueSerializer(UnaryExpression expression, IBsonSerializer sourceSerializer) + { + return BsonValueSerializer.Instance; + } + + static IBsonSerializer GetConvertEnumToEnumSerializer(UnaryExpression expression, Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + if (!sourceType.IsEnum) + { + throw new ExpressionNotSupportedException(expression, because: "source type is not an enum"); + } + if (!targetType.IsEnum) + { + throw new ExpressionNotSupportedException(expression, because: "target type is not an enum"); + } + + return EnumSerializer.Create(targetType); + } + + static IBsonSerializer GetConvertEnumToUnderlyingTypeSerializer(UnaryExpression expression, Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + var enumSerializer = sourceSerializer; + return AsEnumUnderlyingTypeSerializer.Create(enumSerializer); + } + + static IBsonSerializer GetConvertFromNullableTypeSerializer(UnaryExpression expression, Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + if (sourceSerializer is not INullableSerializer nullableSourceSerializer) + { + throw new ExpressionNotSupportedException(expression, because: $"sourceSerializer type {sourceSerializer.GetType()} does not implement nameof(INullableSerializer)"); + } + + var sourceValueSerializer = nullableSourceSerializer.ValueSerializer; + var sourceValueType = sourceValueSerializer.ValueType; + + if (targetType.IsNullable(out var targetValueType)) + { + var targetValueSerializer = GetTargetSerializer(expression, sourceValueType, targetValueType, sourceValueSerializer); + return NullableSerializer.Create(targetValueSerializer); + } + else + { + return GetTargetSerializer(expression, sourceValueType, targetType, sourceValueSerializer); + } + } + + static IBsonSerializer GetConvertToNullableTypeSerializer(UnaryExpression expression, Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + if (sourceType.IsNullable()) + { + throw new ExpressionNotSupportedException(expression, because: "sourceType is already nullable"); + } + + if (targetType.IsNullable()) + { + return NullableSerializer.Create(sourceSerializer); + } + + throw new ExpressionNotSupportedException(expression, because: "targetType is not nullable"); + } + + static IBsonSerializer GetConvertUnderlyingTypeToEnumSerializer(UnaryExpression expression, Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + IBsonSerializer targetSerializer; + if (sourceSerializer is IAsEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer) + { + targetSerializer = enumUnderlyingTypeSerializer.EnumSerializer; + } + else + { + targetSerializer = EnumSerializer.Create(targetType); + } + + return targetSerializer; + } + + static IBsonSerializer GetNumericConversionSerializer(UnaryExpression expression, Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + return NumericConversionSerializer.Create(sourceType, targetType, sourceSerializer); + } + + static bool IsConvertEnumToEnum(Type sourceType, Type targetType) + { + return sourceType.IsEnum && targetType.IsEnum; + } + + static bool IsConvertEnumToUnderlyingType(Type sourceType, Type targetType) + { + return + sourceType.IsEnum(out var underlyingType) && + targetType == underlyingType; + } + + static bool IsConvertFromNullableType(Type sourceType) + { + return sourceType.IsNullable(); + } + + static bool IsConvertToBaseType(Type sourceType, Type targetType) + { + return sourceType.IsSubclassOf(targetType) || sourceType.ImplementsInterface(targetType); + } + + static bool IsConvertToDerivedType(Type sourceType, Type targetType) + { + return sourceType.IsAssignableFrom(targetType); // targetType either derives from sourceType or implements sourceType interface + } + + static bool IsConvertToNullableType(Type targetType, out Type valueType) + { + return targetType.IsNullable(out valueType); + } + + static bool IsConvertUnderlyingTypeToEnum(Type sourceType, Type targetType) + { + return + targetType.IsEnum(out var underlyingType) && + sourceType == underlyingType; + } + + static bool IsNumericConversion(Type sourceType, Type targetType) + { + return sourceType.IsNumeric() && targetType.IsNumeric(); + } + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitor.cs new file mode 100644 index 00000000000..1b6c25e1238 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitor.cs @@ -0,0 +1,71 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; +using ExpressionVisitor = System.Linq.Expressions.ExpressionVisitor; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor : ExpressionVisitor +{ + private bool _isMakingProgress = true; + private readonly SerializerMap _nodeSerializers; + private int _oldNodeSerializersCount = 0; + private readonly ExpressionTranslationOptions _translationOptions; + private bool _useDefaultSerializerForConstants = false; // make as much progress as possible before setting this to true + + public SerializerFinderVisitor(ExpressionTranslationOptions translationOptions, SerializerMap nodeSerializers) + { + _nodeSerializers = nodeSerializers; + _translationOptions = translationOptions; + } + + public bool IsMakingProgress => _isMakingProgress; + + public void EndPass() + { + var newNodeSerializersCount = _nodeSerializers.Count; + if (newNodeSerializersCount == _oldNodeSerializersCount) + { + if (_useDefaultSerializerForConstants) + { + _isMakingProgress = false; + } + else + { + _useDefaultSerializerForConstants = true; + } + } + } + + public void StartPass() + { + _oldNodeSerializersCount = _nodeSerializers.Count; + } + + public override Expression Visit(Expression node) + { + if (IsKnown(node, out var nodeSerializer)) + { + if (nodeSerializer is IIgnoreSubtreeSerializer or IUnknowableSerializer) + { + return node; // don't visit subtree + } + } + + return base.Visit(node); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerMap.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerMap.cs new file mode 100644 index 00000000000..8b35de67223 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerMap.cs @@ -0,0 +1,111 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Core.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal interface IReadOnlySerializerMap +{ + IBsonSerializer GetSerializer(Expression node); +} + +internal class SerializerMap : IReadOnlySerializerMap +{ + private readonly Dictionary _map = new(); + + public int Count => _map.Count; + + public void AddSerializer(Expression node, IBsonSerializer serializer) + { + if (serializer.ValueType != node.Type && + node.Type.IsNullable(out var nodeNonNullableType) && + serializer.ValueType.IsNullable(out var serializerNonNullableType) && + serializer is INullableSerializer nullableSerializer) + { + if (nodeNonNullableType.IsEnum(out var targetEnumUnderlyingType) && targetEnumUnderlyingType == serializerNonNullableType) + { + var enumType = nodeNonNullableType; + var underlyingTypeSerializer = nullableSerializer.ValueSerializer; + var enumSerializer = AsUnderlyingTypeEnumSerializer.Create(enumType, underlyingTypeSerializer); + serializer = NullableSerializer.Create(enumSerializer); + } + else if (serializerNonNullableType.IsEnum(out var serializerUnderlyingType) && serializerUnderlyingType == nodeNonNullableType) + { + var enumSerializer = nullableSerializer.ValueSerializer; + var underlyingTypeSerializer = AsEnumUnderlyingTypeSerializer.Create(enumSerializer); + serializer = NullableSerializer.Create(underlyingTypeSerializer); + } + } + + if (serializer.ValueType != node.Type) + { + if (node.Type.IsAssignableFrom(serializer.ValueType)) + { + serializer = DowncastingSerializer.Create(baseType: node.Type, derivedType: serializer.ValueType, derivedTypeSerializer: serializer); + } + else if (serializer.ValueType.IsAssignableFrom(node.Type)) + { + serializer = UpcastingSerializer.Create(baseType: serializer.ValueType, derivedType: node.Type, baseTypeSerializer: serializer); + } + else + { + throw new ArgumentException($"Serializer value type {serializer.ValueType} does not match expression value type {node.Type}", nameof(serializer)); + } + } + + if (_map.TryGetValue(node, out var existingSerializer)) + { + throw new ExpressionNotSupportedException( + node, + because: $"there are duplicate known serializers for expression '{node}': {serializer.GetType()} and {existingSerializer.GetType()}"); + } + + _map.Add(node, serializer); + } + + public IBsonSerializer GetSerializer(Expression node) + { + if (_map.TryGetValue(node, out var nodeSerializer)) + { + return nodeSerializer; + } + + throw new ExpressionNotSupportedException(node, because: "unable to determine which serializer to use"); + } + + public bool IsNotKnown(Expression node) + { + return !IsKnown(node); + } + + public bool IsKnown(Expression node) + { + return _map.ContainsKey(node); + } + + public bool IsKnown(Expression node, out IBsonSerializer serializer) + { + serializer = null; + return node != null && _map.TryGetValue(node, out serializer); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/EnumUnderlyingTypeSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/AsEnumUnderlyingTypeSerializer.cs similarity index 63% rename from src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/EnumUnderlyingTypeSerializer.cs rename to src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/AsEnumUnderlyingTypeSerializer.cs index 816e5fc237f..7e0b3d1e75c 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/EnumUnderlyingTypeSerializer.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/AsEnumUnderlyingTypeSerializer.cs @@ -20,24 +20,24 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers { - internal interface IEnumUnderlyingTypeSerializer + internal interface IAsEnumUnderlyingTypeSerializer { IBsonSerializer EnumSerializer { get; } } - internal class EnumUnderlyingTypeSerializer : StructSerializerBase, IEnumUnderlyingTypeSerializer + internal class AsEnumUnderlyingTypeSerializer : StructSerializerBase, IAsEnumUnderlyingTypeSerializer where TEnum : Enum - where TEnumUnderlyingType : struct + where TUnderlyingType : struct { // private fields private readonly IBsonSerializer _enumSerializer; // constructors - public EnumUnderlyingTypeSerializer(IBsonSerializer enumSerializer) + public AsEnumUnderlyingTypeSerializer(IBsonSerializer enumSerializer) { - if (typeof(TEnumUnderlyingType) != Enum.GetUnderlyingType(typeof(TEnum))) + if (typeof(TUnderlyingType) != Enum.GetUnderlyingType(typeof(TEnum))) { - throw new ArgumentException($"{typeof(TEnumUnderlyingType).FullName} is not the underlying type of {typeof(TEnum).FullName}."); + throw new ArgumentException($"{typeof(TUnderlyingType).FullName} is not the underlying type of {typeof(TEnum).FullName}."); } _enumSerializer = Ensure.IsNotNull(enumSerializer, nameof(enumSerializer)); } @@ -46,13 +46,13 @@ public EnumUnderlyingTypeSerializer(IBsonSerializer enumSerializer) public IBsonSerializer EnumSerializer => _enumSerializer; // explicitly implemented properties - IBsonSerializer IEnumUnderlyingTypeSerializer.EnumSerializer => EnumSerializer; + IBsonSerializer IAsEnumUnderlyingTypeSerializer.EnumSerializer => EnumSerializer; // public methods - public override TEnumUnderlyingType Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + public override TUnderlyingType Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) { var enumValue = _enumSerializer.Deserialize(context); - return (TEnumUnderlyingType)(object)enumValue; + return (TUnderlyingType)(object)enumValue; } /// @@ -62,28 +62,28 @@ public override bool Equals(object obj) if (object.ReferenceEquals(this, obj)) { return true; } return base.Equals(obj) && - obj is EnumUnderlyingTypeSerializer other && + obj is AsEnumUnderlyingTypeSerializer other && object.Equals(_enumSerializer, other._enumSerializer); } /// public override int GetHashCode() => _enumSerializer.GetHashCode(); - public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, TEnumUnderlyingType value) + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, TUnderlyingType value) { var enumValue = (TEnum)(object)value; _enumSerializer.Serialize(context, enumValue); } } - internal static class EnumUnderlyingTypeSerializer + internal static class AsEnumUnderlyingTypeSerializer { public static IBsonSerializer Create(IBsonSerializer enumSerializer) { var enumType = enumSerializer.ValueType; var underlyingType = Enum.GetUnderlyingType(enumType); - var enumUnderlyingTypeSerializerType = typeof(EnumUnderlyingTypeSerializer<,>).MakeGenericType(enumType, underlyingType); - return (IBsonSerializer)Activator.CreateInstance(enumUnderlyingTypeSerializerType, enumSerializer); + var asEnumUnderlyingTypeSerializerType = typeof(AsEnumUnderlyingTypeSerializer<,>).MakeGenericType(enumType, underlyingType); + return (IBsonSerializer)Activator.CreateInstance(asEnumUnderlyingTypeSerializerType, enumSerializer); } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/AsUnderlyingTypeEnumSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/AsUnderlyingTypeEnumSerializer.cs new file mode 100644 index 00000000000..41f673af856 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/AsUnderlyingTypeEnumSerializer.cs @@ -0,0 +1,88 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Core.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers +{ + internal interface IAsUnderlyingTypeEnumSerializer + { + IBsonSerializer UnderlyingTypeSerializer { get; } + } + + internal class AsUnderlyingTypeEnumSerializer : SerializerBase, IAsUnderlyingTypeEnumSerializer + where TEnum : Enum + where TUnderlyingType : struct + { + // private fields + private readonly IBsonSerializer _underlyingTypeSerializer; + + // constructors + public AsUnderlyingTypeEnumSerializer(IBsonSerializer underlyingTypeSerializer) + { + if (typeof(TUnderlyingType) != Enum.GetUnderlyingType(typeof(TEnum))) + { + throw new ArgumentException($"{typeof(TUnderlyingType).FullName} is not the underlying type of {typeof(TEnum).FullName}."); + } + _underlyingTypeSerializer = Ensure.IsNotNull(underlyingTypeSerializer, nameof(underlyingTypeSerializer)); + } + + // public properties + public IBsonSerializer UnderlyingTypeSerializer => _underlyingTypeSerializer; + + // explicitly implemented properties + IBsonSerializer IAsUnderlyingTypeEnumSerializer.UnderlyingTypeSerializer => UnderlyingTypeSerializer; + + // public methods + public override TEnum Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + { + var underlyingTypeValue = _underlyingTypeSerializer.Deserialize(context); + return (TEnum)(object)underlyingTypeValue; + } + + /// + public override bool Equals(object obj) + { + if (object.ReferenceEquals(obj, null)) { return false; } + if (object.ReferenceEquals(this, obj)) { return true; } + return + base.Equals(obj) && + obj is AsUnderlyingTypeEnumSerializer other && + object.Equals(_underlyingTypeSerializer, other._underlyingTypeSerializer); + } + + /// + public override int GetHashCode() => _underlyingTypeSerializer.GetHashCode(); + + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, TEnum value) + { + var underlyingTypeValue = (TUnderlyingType)(object)value; + _underlyingTypeSerializer.Serialize(context, underlyingTypeValue); + } + } + + internal static class AsUnderlyingTypeEnumSerializer + { + public static IBsonSerializer Create(Type enumType, IBsonSerializer underlyingTypeSerializer) + { + var underlyingType = Enum.GetUnderlyingType(enumType); + var asUnderlyingTypeEnumSerializerType = typeof(AsUnderlyingTypeEnumSerializer<,>).MakeGenericType(enumType, underlyingType); + return (IBsonSerializer)Activator.CreateInstance(asUnderlyingTypeEnumSerializerType, underlyingTypeSerializer); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/DictionarySerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/DictionarySerializer.cs index bfecb1ef9c7..0d56847649e 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/DictionarySerializer.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/DictionarySerializer.cs @@ -45,8 +45,7 @@ public DictionarySerializer( { } - protected override ICollection> CreateAccumulator() - { - return new Dictionary(); - } + protected override ICollection> CreateAccumulator() => new Dictionary(); + + protected override DictionaryFinalizeAccumulator(ICollection> accumulator) => (Dictionary)accumulator; } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/HashSetSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/HashSetSerializer.cs new file mode 100644 index 00000000000..87a47747e5f --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/HashSetSerializer.cs @@ -0,0 +1,42 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class HashSetSerializer +{ + public static IBsonSerializer Create(IBsonSerializer itemSerializer) + { + var serializerType = typeof(HashSetSerializer<>).MakeGenericType(itemSerializer.ValueType); + return (IBsonSerializer)Activator.CreateInstance(serializerType, itemSerializer); + } +} + +internal class HashSetSerializer : EnumerableInterfaceImplementerSerializerBase, T> +{ + public HashSetSerializer(IBsonSerializer itemSerializer) + : base(itemSerializer) + { + } + + protected override object CreateAccumulator() => new HashSet(); + + protected override HashSet FinalizeResult(object accumulator) => (HashSet)accumulator; +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IEnumerableOrIQueryableSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IEnumerableOrIQueryableSerializer.cs new file mode 100644 index 00000000000..f03bf327711 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IEnumerableOrIQueryableSerializer.cs @@ -0,0 +1,30 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using MongoDB.Bson.Serialization; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class IEnumerableOrIQueryableSerializer +{ + public static IBsonSerializer Create(Type enumerableOrQueryableType, IBsonSerializer itemSerializer) + { + return enumerableOrQueryableType.ImplementsIQueryable(out _) ? + IQueryableSerializer.Create(itemSerializer) : + IEnumerableSerializer.Create(itemSerializer); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IOrderedEnumerableOrIOrderedQueryableSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IOrderedEnumerableOrIOrderedQueryableSerializer.cs new file mode 100644 index 00000000000..da44f92e218 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IOrderedEnumerableOrIOrderedQueryableSerializer.cs @@ -0,0 +1,30 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using MongoDB.Bson.Serialization; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class IOrderedEnumerableOrIOrderedQueryableSerializer +{ + public static IBsonSerializer Create(Type enumerableOrQueryableType, IBsonSerializer itemSerializer) + { + return enumerableOrQueryableType.ImplementsIOrderedQueryable(out _) ? + IOrderedQueryableSerializer.Create(itemSerializer) : + IOrderedEnumerableSerializer.Create(itemSerializer); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/ISetWindowFieldsPartitionSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/ISetWindowFieldsPartitionSerializer.cs index 2be9f49a1b3..b169febe181 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/ISetWindowFieldsPartitionSerializer.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/ISetWindowFieldsPartitionSerializer.cs @@ -24,7 +24,7 @@ internal interface ISetWindowFieldsPartitionSerializer IBsonSerializer InputSerializer { get; } } - internal class ISetWindowFieldsPartitionSerializer : IBsonSerializer>, ISetWindowFieldsPartitionSerializer + internal class ISetWindowFieldsPartitionSerializer : IBsonSerializer>, ISetWindowFieldsPartitionSerializer, IBsonArraySerializer { private readonly IBsonSerializer _inputSerializer; @@ -61,16 +61,20 @@ public void Serialize(BsonSerializationContext context, BsonSerializationArgs ar throw new InvalidOperationException("This serializer is not intended to be used."); } - public void Serialize(BsonSerializationContext context, BsonSerializationArgs args, object value) { throw new InvalidOperationException("This serializer is not intended to be used."); } - object IBsonSerializer.Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) { throw new InvalidOperationException("This serializer is not intended to be used."); } + + public bool TryGetItemSerializationInfo(out BsonSerializationInfo itemSerializationInfo) + { + itemSerializationInfo = new BsonSerializationInfo(null, _inputSerializer, _inputSerializer.ValueType); + return true; + } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IgnoreNodeSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IgnoreNodeSerializer.cs new file mode 100644 index 00000000000..23fb02f7db8 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IgnoreNodeSerializer.cs @@ -0,0 +1,33 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class IgnoreNodeSerializer +{ + public static IBsonSerializer Create(Type valueType) + { + var serializerType = typeof(IgnoreNodeSerializer<>).MakeGenericType(valueType); + return (IBsonSerializer)Activator.CreateInstance(serializerType); + } +} + +internal class IgnoreNodeSerializer : SerializerBase +{ +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IgnoreSubtreeSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IgnoreSubtreeSerializer.cs new file mode 100644 index 00000000000..5476eb1e747 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IgnoreSubtreeSerializer.cs @@ -0,0 +1,37 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class IgnoreSubtreeSerializer +{ + public static IBsonSerializer Create(Type valueType) + { + var serializerType = typeof(IgnoreSubtreeSerializer<>).MakeGenericType(valueType); + return (IBsonSerializer)Activator.CreateInstance(serializerType); + } +} + +internal interface IIgnoreSubtreeSerializer +{ +} + +internal class IgnoreSubtreeSerializer : SerializerBase, IIgnoreSubtreeSerializer +{ +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/ListSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/ListSerializer.cs new file mode 100644 index 00000000000..2a7044e7116 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/ListSerializer.cs @@ -0,0 +1,42 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class ListSerializer +{ + public static IBsonSerializer Create(IBsonSerializer itemSerializer) + { + var serializerType = typeof(ListSerializer<>).MakeGenericType(itemSerializer.ValueType); + return (IBsonSerializer)Activator.CreateInstance(serializerType, itemSerializer); + } +} + +internal class ListSerializer : EnumerableInterfaceImplementerSerializerBase, T> +{ + public ListSerializer(IBsonSerializer itemSerializer) + : base(itemSerializer) + { + } + + protected override object CreateAccumulator() => new List(); + + protected override List FinalizeResult(object accumulator) => (List)accumulator; +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/NumericConversionSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/NumericConversionSerializer.cs new file mode 100644 index 00000000000..c09e78a713c --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/NumericConversionSerializer.cs @@ -0,0 +1,77 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class NumericConversionSerializer +{ + public static IBsonSerializer Create(Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + var serializerType = typeof(NumericConversionSerializer<,>).MakeGenericType(sourceType, targetType); + return (IBsonSerializer)Activator.CreateInstance(serializerType, sourceSerializer); + } +} + +internal class NumericConversionSerializer : SerializerBase, IHasRepresentationSerializer +{ + private readonly BsonType _representation; + private readonly IBsonSerializer _sourceSerializer; + + public BsonType Representation => _representation; + + public NumericConversionSerializer(IBsonSerializer sourceSerializer) + { + if (sourceSerializer is not IHasRepresentationSerializer hasRepresentationSerializer) + { + throw new NotSupportedException($"Serializer class {sourceSerializer.GetType().Name} does not implement IHasRepresentationSerializer."); + } + + _sourceSerializer = sourceSerializer; + _representation = hasRepresentationSerializer.Representation; + } + + public override TTarget Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + { + var sourceValue = _sourceSerializer.Deserialize(context); + return (TTarget)Convert(typeof(TSource), typeof(TTarget), sourceValue); + } + + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, TTarget value) + { + var sourceValue = Convert(typeof(TTarget), typeof(TSource), value); + _sourceSerializer.Serialize(context, args, sourceValue); + } + + private object Convert(Type sourceType, Type targetType, object value) + { + return (Type.GetTypeCode(sourceType), Type.GetTypeCode(targetType)) switch + { + (TypeCode.Decimal, TypeCode.Double) => (object)(double)(decimal)value, + (TypeCode.Double, TypeCode.Decimal) => (object)(decimal)(double)value, + (TypeCode.Int16, TypeCode.Int32) => (object)(int)(short)value, + (TypeCode.Int16, TypeCode.Int64) => (object)(long)(short)value, + (TypeCode.Int32, TypeCode.Int16) => (object)(short)(int)value, + (TypeCode.Int32, TypeCode.Int64) => (object)(long)(int)value, + (TypeCode.Int64, TypeCode.Int16) => (object)(short)(long)value, + (TypeCode.Int64, TypeCode.Int32) => (object)(int)(long)value, + _ => throw new NotSupportedException($"Cannot convert {sourceType} to {targetType}."), + }; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/PolymorphicArraySerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/PolymorphicArraySerializer.cs new file mode 100644 index 00000000000..beb65eee63e --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/PolymorphicArraySerializer.cs @@ -0,0 +1,98 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal interface IPolymorphicArraySerializer +{ + IBsonSerializer GetItemSerializer(int index); +} + +internal static class PolymorphicArraySerializer +{ + public static IBsonSerializer Create(Type itemType, IEnumerable itemSerializers) + { + var serializerType = typeof(PolymorphicArraySerializer<>).MakeGenericType(itemType); + return (IBsonSerializer)Activator.CreateInstance(serializerType, itemSerializers); + } +} + +internal sealed class PolymorphicArraySerializer : SerializerBase, IPolymorphicArraySerializer +{ + private readonly IReadOnlyList _itemSerializers; + + public PolymorphicArraySerializer(IEnumerable itemSerializers) + { + var itemSerializersArray = itemSerializers.ToArray(); + foreach (var itemSerializer in itemSerializersArray) + { + if (!typeof(TItem).IsAssignableFrom(itemSerializer.ValueType)) + { + throw new ArgumentException($"Serializer class {itemSerializer.ValueType} value type is not assignable to item type {typeof(TItem).Name}"); + } + } + + _itemSerializers = itemSerializersArray; + } + + public override TItem[] Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + { + var reader = context.Reader; + + reader.ReadStartArray(); + var i = 0; + var array = new TItem[_itemSerializers.Count]; + while (reader.ReadBsonType() != BsonType.EndOfDocument) + { + if (i < array.Length) + { + array[i] = (TItem)_itemSerializers[i].Deserialize(context); + i++; + } + } + if (i != array.Length) + { + throw new BsonSerializationException($"Expected {array.Length} array items but found {i}."); + } + reader.ReadEndArray(); + + return array; + } + + IBsonSerializer IPolymorphicArraySerializer.GetItemSerializer(int index) => _itemSerializers[index]; + + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, TItem[] value) + { + if (value.Length != _itemSerializers.Count) + { + throw new BsonSerializationException($"Expected array value to have {_itemSerializers.Count} items but found {value.Length}."); + } + + var writer = context.Writer; + writer.WriteStartArray(); + for (var i = 0; i < value.Length; i++) + { + _itemSerializers[i].Serialize(context, args, value[i]); + } + writer.WriteEndArray(); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/TupleOrValueTupleSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/TupleOrValueTupleSerializer.cs new file mode 100644 index 00000000000..762b2839ee8 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/TupleOrValueTupleSerializer.cs @@ -0,0 +1,35 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class TupleOrValueTupleSerializer +{ + public static IBsonSerializer Create(Type tupleType, IEnumerable itemSerializers) + { + return tupleType.Name switch + { + _ when tupleType.IsTuple() => TupleSerializer.Create(itemSerializers), + _ when tupleType.IsValueTuple() => ValueTupleSerializer.Create(itemSerializers), + _ => throw new ArgumentException($"Unexpected tuple type: {tupleType.Name}", nameof(tupleType)) + }; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/UnknowableSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/UnknowableSerializer.cs new file mode 100644 index 00000000000..e3e6583408b --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/UnknowableSerializer.cs @@ -0,0 +1,37 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class UnknowableSerializer +{ + public static IBsonSerializer Create(Type valueType) + { + var serializerType = typeof(UnknowableSerializer<>).MakeGenericType(valueType); + return (IBsonSerializer)Activator.CreateInstance(serializerType); + } +} + +internal interface IUnknowableSerializer +{ +} + +internal class UnknowableSerializer : SerializerBase, IUnknowableSerializer +{ +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/UpcastingSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/UpcastingSerializer.cs new file mode 100644 index 00000000000..e2843cb8602 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/UpcastingSerializer.cs @@ -0,0 +1,92 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers +{ + internal static class UpcastingSerializer + { + public static IBsonSerializer Create( + Type baseType, + Type derivedType, + IBsonSerializer baseTypeSerializer) + { + var upcastingSerializerType = typeof(UpcastingSerializer<,>).MakeGenericType(baseType, derivedType); + return (IBsonSerializer)Activator.CreateInstance(upcastingSerializerType, baseTypeSerializer); + } + } + + internal sealed class UpcastingSerializer : SerializerBase, IBsonArraySerializer, IBsonDocumentSerializer + where TDerived : TBase + { + private readonly IBsonSerializer _baseTypeSerializer; + + public UpcastingSerializer(IBsonSerializer baseTypeSerializer) + { + _baseTypeSerializer = baseTypeSerializer ?? throw new ArgumentNullException(nameof(baseTypeSerializer)); + } + + public Type BaseType => typeof(TBase); + + public IBsonSerializer BaseTypeSerializer => _baseTypeSerializer; + + public Type DerivedType => typeof(TDerived); + + public override TDerived Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + { + return (TDerived)_baseTypeSerializer.Deserialize(context); + } + + public override bool Equals(object obj) + { + if (object.ReferenceEquals(obj, null)) { return false; } + if (object.ReferenceEquals(this, obj)) { return true; } + return + base.Equals(obj) && + obj is UpcastingSerializer other && + object.Equals(_baseTypeSerializer, other._baseTypeSerializer); + } + + public override int GetHashCode() => 0; + + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, TDerived value) + { + _baseTypeSerializer.Serialize(context, value); + } + + public bool TryGetItemSerializationInfo(out BsonSerializationInfo serializationInfo) + { + if (_baseTypeSerializer is not IBsonArraySerializer arraySerializer) + { + throw new NotSupportedException($"The class {_baseTypeSerializer.GetType().FullName} does not implement IBsonArraySerializer."); + } + + return arraySerializer.TryGetItemSerializationInfo(out serializationInfo); + } + + public bool TryGetMemberSerializationInfo(string memberName, out BsonSerializationInfo serializationInfo) + { + if (_baseTypeSerializer is not IBsonDocumentSerializer documentSerializer) + { + throw new NotSupportedException($"The class {_baseTypeSerializer.GetType().FullName} does not implement IBsonDocumentSerializer."); + } + + return documentSerializer.TryGetMemberSerializationInfo(memberName, out serializationInfo); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/WrappedValueSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/WrappedValueSerializer.cs index f3bb40aaf3a..c66f84b213e 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/WrappedValueSerializer.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/WrappedValueSerializer.cs @@ -98,6 +98,20 @@ public bool TryGetItemSerializationInfo(out BsonSerializationInfo serializationI public bool TryGetMemberSerializationInfo(string memberName, out BsonSerializationInfo serializationInfo) { + if (_valueSerializer is IBsonDocumentSerializer documentSerializer) + { + if (documentSerializer.TryGetMemberSerializationInfo(memberName, out serializationInfo)) + { + serializationInfo = BsonSerializationInfo.CreateWithPath( + [_fieldName, serializationInfo.ElementName], + serializationInfo.Serializer, + serializationInfo.NominalType); + return true; + } + + return false; + } + throw new InvalidOperationException(); } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ArrayIndexExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ArrayIndexExpressionToAggregationExpressionTranslator.cs index 818a92fab7a..3462d1bcf3e 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ArrayIndexExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ArrayIndexExpressionToAggregationExpressionTranslator.cs @@ -14,8 +14,11 @@ */ using System.Linq.Expressions; +using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods; using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { @@ -30,7 +33,8 @@ public static TranslatedExpression Translate(TranslationContext context, BinaryE var indexExpression = expression.Right; var indexTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, indexExpression); var ast = AstExpression.ArrayElemAt(arrayTranslation.Ast, indexTranslation.Ast); - var itemSerializer = ArraySerializerHelper.GetItemSerializer(arrayTranslation.Serializer); + var arraySerializer = arrayTranslation.Serializer; + var itemSerializer = arraySerializer.GetItemSerializer(indexExpression, arrayExpression); return new TranslatedExpression(expression, ast, itemSerializer); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs index 7487627213d..33a0f678737 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs @@ -23,12 +23,11 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class ConstantExpressionToAggregationExpressionTranslator { - public static TranslatedExpression Translate(ConstantExpression constantExpression) + public static TranslatedExpression Translate(TranslationContext context, ConstantExpression constantExpression) { - var constantType = constantExpression.Type; - var constantSerializer = StandardSerializers.TryGetSerializer(constantType, out var serializer) ? serializer : BsonSerializer.LookupSerializer(constantType); + var constantSerializer = context.NodeSerializers.GetSerializer(constantExpression); return Translate(constantExpression, constantSerializer); - } + } public static TranslatedExpression Translate(ConstantExpression constantExpression, IBsonSerializer constantSerializer) { diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs index 532e10c1609..90cf9d8c45d 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs @@ -214,7 +214,7 @@ private static TranslatedExpression TranslateConvertEnumToEnum(UnaryExpression e private static TranslatedExpression TranslateConvertEnumToUnderlyingType(UnaryExpression expression, Type sourceType, Type targetType, TranslatedExpression sourceTranslation) { var enumSerializer = sourceTranslation.Serializer; - var targetSerializer = EnumUnderlyingTypeSerializer.Create(enumSerializer); + var targetSerializer = AsEnumUnderlyingTypeSerializer.Create(enumSerializer); return new TranslatedExpression(expression, sourceTranslation.Ast, targetSerializer); } @@ -265,7 +265,7 @@ private static TranslatedExpression TranslateConvertUnderlyingTypeToEnum(UnaryEx var valueSerializer = sourceTranslation.Serializer; IBsonSerializer targetSerializer; - if (valueSerializer is IEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer) + if (valueSerializer is IAsEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer) { targetSerializer = enumUnderlyingTypeSerializer.EnumSerializer; } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs index 9f019682a63..077decb5f4c 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs @@ -75,7 +75,7 @@ public static TranslatedExpression TranslateWithoutUnwrapping(TranslationContext case ExpressionType.Conditional: return ConditionalExpressionToAggregationExpressionTranslator.Translate(context, (ConditionalExpression)expression); case ExpressionType.Constant: - return ConstantExpressionToAggregationExpressionTranslator.Translate((ConstantExpression)expression); + return ConstantExpressionToAggregationExpressionTranslator.Translate(context, (ConstantExpression)expression); case ExpressionType.Index: return IndexExpressionToAggregationExpressionTranslator.Translate(context, (IndexExpression)expression); case ExpressionType.ListInit: diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs index 20f7e81312c..78170c98f61 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs @@ -13,16 +13,14 @@ * limitations under the License. */ -using System; using System.Collections.Generic; -using System.Linq; using System.Linq.Expressions; -using System.Reflection; using MongoDB.Bson; using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { @@ -44,168 +42,63 @@ public static TranslatedExpression Translate( NewExpression newExpression, IReadOnlyList bindings) { + var nodeSerializer = context.NodeSerializers.GetSerializer(expression); var constructorInfo = newExpression.Constructor; // note: can be null when using the default constructor with a struct var constructorArguments = newExpression.Arguments; - var computedFields = new List(); - var classMap = CreateClassMap(newExpression.Type, constructorInfo, out var creatorMap); - if (constructorInfo != null && creatorMap != null) + var computedFields = new List(); + if (constructorInfo != null && constructorArguments.Count > 0) { - var constructorParameters = constructorInfo.GetParameters(); - var creatorMapParameters = creatorMap.Arguments?.ToArray(); - if (constructorParameters.Length > 0) + var matchingMemberSerializationInfos = nodeSerializer.GetMatchingMemberSerializationInfosForConstructorParameters(expression, constructorInfo); + + for (var i = 0; i < constructorArguments.Count; i++) { - if (creatorMapParameters == null) - { - throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching properties for constructor parameters."); - } - if (creatorMapParameters.Length != constructorParameters.Length) - { - throw new ExpressionNotSupportedException(expression, because: $"the constructor has {constructorParameters} parameters but the creatorMap has {creatorMapParameters.Length} parameters."); - } + var argument = constructorArguments[i]; + var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argument); + var matchingMemberSerializationInfo = matchingMemberSerializationInfos[i]; - for (var i = 0; i < creatorMapParameters.Length; i++) + if (!argumentTranslation.Serializer.CanBeAssignedTo(matchingMemberSerializationInfo.Serializer)) { - var creatorMapParameter = creatorMapParameters[i]; - var constructorArgumentExpression = constructorArguments[i]; - var constructorArgumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, constructorArgumentExpression); - var constructorArgumentType = constructorArgumentExpression.Type; - var constructorArgumentSerializer = constructorArgumentTranslation.Serializer ?? BsonSerializer.LookupSerializer(constructorArgumentType); - var memberMap = EnsureMemberMap(expression, classMap, creatorMapParameter); - EnsureDefaultValue(memberMap); - var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, constructorArgumentSerializer); - memberMap.SetSerializer(memberSerializer); - computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, constructorArgumentTranslation.Ast)); + throw new ExpressionNotSupportedException(argument, expression, because: "argument serializer is not equal to member serializer"); } - } - } - - foreach (var binding in bindings) - { - var memberAssignment = (MemberAssignment)binding; - var member = memberAssignment.Member; - var memberMap = FindMemberMap(expression, classMap, member.Name); - var valueExpression = memberAssignment.Expression; - var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); - var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, valueTranslation.Serializer); - memberMap.SetSerializer(memberSerializer); - computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, valueTranslation.Ast)); - } - - var ast = AstExpression.ComputedDocument(computedFields); - classMap.Freeze(); - var serializerType = typeof(BsonClassMapSerializer<>).MakeGenericType(newExpression.Type); - var serializer = (IBsonSerializer)Activator.CreateInstance(serializerType, classMap); - - return new TranslatedExpression(expression, ast, serializer); - } - - private static BsonClassMap CreateClassMap(Type classType, ConstructorInfo constructorInfo, out BsonCreatorMap creatorMap) - { - BsonClassMap baseClassMap = null; - if (classType.BaseType != null) - { - baseClassMap = CreateClassMap(classType.BaseType, null, out _); - } - - var classMapType = typeof(BsonClassMap<>).MakeGenericType(classType); - var classMapConstructorInfo = classMapType.GetConstructor(new Type[] { typeof(BsonClassMap) }); - var classMap = (BsonClassMap)classMapConstructorInfo.Invoke(new object[] { baseClassMap }); - if (constructorInfo != null) - { - creatorMap = classMap.MapConstructor(constructorInfo); - } - else - { - creatorMap = null; - } - - classMap.AutoMap(); - classMap.IdMemberMap?.SetElementName("_id"); // normally happens when Freeze is called but we need it sooner here - - return classMap; - } - - private static IBsonSerializer CoerceSourceSerializerToMemberSerializer(BsonMemberMap memberMap, IBsonSerializer sourceSerializer) - { - var memberType = memberMap.MemberType; - var memberSerializer = memberMap.GetSerializer(); - var sourceType = sourceSerializer.ValueType; - if (memberType != sourceType && - memberType.ImplementsIEnumerable(out var memberItemType) && - sourceType.ImplementsIEnumerable(out var sourceItemType) && - sourceItemType == memberItemType && - sourceSerializer is IBsonArraySerializer sourceArraySerializer && - sourceArraySerializer.TryGetItemSerializationInfo(out var sourceItemSerializationInfo) && - memberSerializer is IChildSerializerConfigurable memberChildSerializerConfigurable) - { - var sourceItemSerializer = sourceItemSerializationInfo.Serializer; - return memberChildSerializerConfigurable.WithChildSerializer(sourceItemSerializer); - } - - return sourceSerializer; - } - - private static BsonMemberMap EnsureMemberMap(Expression expression, BsonClassMap classMap, MemberInfo creatorMapParameter) - { - var declaringClassMap = classMap; - while (declaringClassMap.ClassType != creatorMapParameter.DeclaringType) - { - declaringClassMap = declaringClassMap.BaseClassMap; - - if (declaringClassMap == null) - { - throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching property for constructor parameter: {creatorMapParameter.Name}"); + var computedField = AstExpression.ComputedField(matchingMemberSerializationInfo.ElementName, argumentTranslation.Ast); + computedFields.Add(computedField); } } - foreach (var memberMap in declaringClassMap.DeclaredMemberMaps) + if (bindings.Count > 0) { - if (MemberMapMatchesCreatorMapParameter(memberMap, creatorMapParameter)) + if (nodeSerializer is not IBsonDocumentSerializer documentSerializer) { - return memberMap; + throw new ExpressionNotSupportedException(expression, because: $"serializer type {nodeSerializer.GetType()} does not implement IBsonDocumentSerializer"); } - } - return declaringClassMap.MapMember(creatorMapParameter); + foreach (var binding in bindings) + { + var memberAssignment = (MemberAssignment)binding; + var member = memberAssignment.Member; - static bool MemberMapMatchesCreatorMapParameter(BsonMemberMap memberMap, MemberInfo creatorMapParameter) - { - var memberInfo = memberMap.MemberInfo; - return - memberInfo.MemberType == creatorMapParameter.MemberType && - memberInfo.Name.Equals(creatorMapParameter.Name, StringComparison.OrdinalIgnoreCase); - } - } + if (!documentSerializer.TryGetMemberSerializationInfo(member.Name, out var memberSerializationInfo)) + { + throw new ExpressionNotSupportedException(expression, because: $"member {member.Name} was not found"); + } - private static void EnsureDefaultValue(BsonMemberMap memberMap) - { - if (memberMap.IsDefaultValueSpecified) - { - return; - } + var valueExpression = memberAssignment.Expression; + var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); - var defaultValue = memberMap.MemberType.IsValueType ? Activator.CreateInstance(memberMap.MemberType) : null; - memberMap.SetDefaultValue(defaultValue); - } + if (!valueTranslation.Serializer.CanBeAssignedTo(memberSerializationInfo.Serializer)) + { + throw new ExpressionNotSupportedException(valueExpression, expression, because: $"value serializer is not equal to serializer for member {member.Name}"); + } - private static BsonMemberMap FindMemberMap(Expression expression, BsonClassMap classMap, string memberName) - { - foreach (var memberMap in classMap.DeclaredMemberMaps) - { - if (memberMap.MemberName == memberName) - { - return memberMap; + var computedField = AstExpression.ComputedField(memberSerializationInfo.ElementName, valueTranslation.Ast); + computedFields.Add(computedField); } } - if (classMap.BaseClassMap != null) - { - return FindMemberMap(expression, classMap.BaseClassMap, memberName); - } - - throw new ExpressionNotSupportedException(expression, because: $"can't find member map: {memberName}"); + var ast = AstExpression.ComputedDocument(computedFields); + return new TranslatedExpression(expression, ast, nodeSerializer); } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AbsMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AbsMethodToAggregationExpressionTranslator.cs index 36d0e03a6b4..56dc9607212 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AbsMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AbsMethodToAggregationExpressionTranslator.cs @@ -23,23 +23,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class AbsMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __absMethods = - { - MathMethod.AbsDecimal, - MathMethod.AbsDouble, - MathMethod.AbsInt16, - MathMethod.AbsInt32, - MathMethod.AbsInt64, - MathMethod.AbsSByte, - MathMethod.AbsSingle - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__absMethods)) + if (method.IsOneOf(MathMethod.AbsOverloads)) { var valueExpression = arguments[0]; var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AggregateMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AggregateMethodToAggregationExpressionTranslator.cs index 8abb03ac872..959df291eb7 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AggregateMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AggregateMethodToAggregationExpressionTranslator.cs @@ -24,49 +24,19 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class AggregateMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __aggregateMethods = - { - EnumerableMethod.AggregateWithFunc, - EnumerableMethod.AggregateWithSeedAndFunc, - EnumerableMethod.AggregateWithSeedFuncAndResultSelector, - QueryableMethod.AggregateWithFunc, - QueryableMethod.AggregateWithSeedAndFunc, - QueryableMethod.AggregateWithSeedFuncAndResultSelector - }; - - private static readonly MethodInfo[] __aggregateWithoutSeedMethods = - { - EnumerableMethod.AggregateWithFunc, - QueryableMethod.AggregateWithFunc - }; - - private static readonly MethodInfo[] __aggregateWithSeedMethods = - { - EnumerableMethod.AggregateWithSeedAndFunc, - EnumerableMethod.AggregateWithSeedFuncAndResultSelector, - QueryableMethod.AggregateWithSeedAndFunc, - QueryableMethod.AggregateWithSeedFuncAndResultSelector - }; - - private static readonly MethodInfo[] __aggregateWithSeedFuncAndResultSelectorMethods = - { - EnumerableMethod.AggregateWithSeedFuncAndResultSelector, - QueryableMethod.AggregateWithSeedFuncAndResultSelector - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__aggregateMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.AggregateOverloads)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); - if (method.IsOneOf(__aggregateWithoutSeedMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.AggregateWithFunc)) { var funcLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); var funcParameters = funcLambda.Parameters; @@ -95,7 +65,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC return new TranslatedExpression(expression, ast, itemSerializer); } - else if (method.IsOneOf(__aggregateWithSeedMethods)) + else if (method.IsOneOf(EnumerableOrQueryableMethod.AggregateWithSeedOverloads)) { var seedExpression = arguments[1]; var seedTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, seedExpression); @@ -116,7 +86,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC @in: funcTranslation.Ast); var serializer = accumulatorSerializer; - if (method.IsOneOf(__aggregateWithSeedFuncAndResultSelectorMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.AggregateWithSeedFuncAndResultSelector)) { var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); var resultSelectorAccumulatorParameter = resultSelectorLambda.Parameters[0]; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AllMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AllMethodToAggregationExpressionTranslator.cs index 290f49185a0..f10a7bcf416 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AllMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AllMethodToAggregationExpressionTranslator.cs @@ -24,18 +24,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class AllMethodToAggregationExpressionTranslator { - private readonly static MethodInfo[] __allMethods = - { - EnumerableMethod.All, - QueryableMethod.All - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__allMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.All)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AnyMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AnyMethodToAggregationExpressionTranslator.cs index 5841d67f823..f89d84896e8 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AnyMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AnyMethodToAggregationExpressionTranslator.cs @@ -24,19 +24,6 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class AnyMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __anyMethods = - { - EnumerableMethod.Any, - QueryableMethod.Any - }; - - private static readonly MethodInfo[] __anyWithPredicateMethods = - { - EnumerableMethod.AnyWithPredicate, - QueryableMethod.AnyWithPredicate, - ArrayMethod.Exists - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; @@ -46,13 +33,13 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); - if (method.IsOneOf(__anyMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.Any)) { var ast = AstExpression.Gt(AstExpression.Size(sourceTranslation.Ast), 0); return new TranslatedExpression(expression, ast, new BooleanSerializer()); } - if (method.IsOneOf(__anyWithPredicateMethods) || ListMethod.IsExistsMethod(method)) + if (method.IsOneOf(EnumerableOrQueryableMethod.AnyWithPredicate) || method.Is(ArrayMethod.Exists) || ListMethod.IsExistsMethod(method)) { var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, method.IsStatic ? arguments[1] : arguments[0]); var predicateParameter = predicateLambda.Parameters[0]; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AppendOrPrependMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AppendOrPrependMethodToAggregationExpressionTranslator.cs index 5a7a2942f70..ecee001f1c8 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AppendOrPrependMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AppendOrPrependMethodToAggregationExpressionTranslator.cs @@ -24,26 +24,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class AppendOrPrependMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __appendOrPrependMethods = - { - EnumerableMethod.Append, - EnumerableMethod.Prepend, - QueryableMethod.Append, - QueryableMethod.Prepend - }; - - private static readonly MethodInfo[] __appendMethods = - { - EnumerableMethod.Append, - QueryableMethod.Append - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__appendOrPrependMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.AppendOrPrepend)) { var sourceExpression = arguments[0]; var elementExpression = arguments[1]; @@ -68,7 +54,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC } } - var ast = method.IsOneOf(__appendMethods) ? + var ast = method.IsOneOf(EnumerableOrQueryableMethod.Append) ? AstExpression.ConcatArrays(sourceTranslation.Ast, AstExpression.ComputedArray(elementTranslation.Ast)) : AstExpression.ConcatArrays(AstExpression.ComputedArray(elementTranslation.Ast), sourceTranslation.Ast); var serializer = NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(expression.Type, itemSerializer); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AverageMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AverageMethodToAggregationExpressionTranslator.cs index f2849ef812c..35b0f07e624 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AverageMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AverageMethodToAggregationExpressionTranslator.cs @@ -26,87 +26,19 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class AverageMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __averageMethods = - { - EnumerableMethod.AverageDecimal, - EnumerableMethod.AverageDecimalWithSelector, - EnumerableMethod.AverageDouble, - EnumerableMethod.AverageDoubleWithSelector, - EnumerableMethod.AverageInt32, - EnumerableMethod.AverageInt32WithSelector, - EnumerableMethod.AverageInt64, - EnumerableMethod.AverageInt64WithSelector, - EnumerableMethod.AverageNullableDecimal, - EnumerableMethod.AverageNullableDecimalWithSelector, - EnumerableMethod.AverageNullableDouble, - EnumerableMethod.AverageNullableDoubleWithSelector, - EnumerableMethod.AverageNullableInt32, - EnumerableMethod.AverageNullableInt32WithSelector, - EnumerableMethod.AverageNullableInt64, - EnumerableMethod.AverageNullableInt64WithSelector, - EnumerableMethod.AverageNullableSingle, - EnumerableMethod.AverageNullableSingleWithSelector, - EnumerableMethod.AverageSingle, - EnumerableMethod.AverageSingleWithSelector, - QueryableMethod.AverageDecimal, - QueryableMethod.AverageDecimalWithSelector, - QueryableMethod.AverageDouble, - QueryableMethod.AverageDoubleWithSelector, - QueryableMethod.AverageInt32, - QueryableMethod.AverageInt32WithSelector, - QueryableMethod.AverageInt64, - QueryableMethod.AverageInt64WithSelector, - QueryableMethod.AverageNullableDecimal, - QueryableMethod.AverageNullableDecimalWithSelector, - QueryableMethod.AverageNullableDouble, - QueryableMethod.AverageNullableDoubleWithSelector, - QueryableMethod.AverageNullableInt32, - QueryableMethod.AverageNullableInt32WithSelector, - QueryableMethod.AverageNullableInt64, - QueryableMethod.AverageNullableInt64WithSelector, - QueryableMethod.AverageNullableSingle, - QueryableMethod.AverageNullableSingleWithSelector, - QueryableMethod.AverageSingle, - QueryableMethod.AverageSingleWithSelector - }; - - private static readonly MethodInfo[] __averageWithSelectorMethods = - { - EnumerableMethod.AverageDecimalWithSelector, - EnumerableMethod.AverageDoubleWithSelector, - EnumerableMethod.AverageInt32WithSelector, - EnumerableMethod.AverageInt64WithSelector, - EnumerableMethod.AverageNullableDecimalWithSelector, - EnumerableMethod.AverageNullableDoubleWithSelector, - EnumerableMethod.AverageNullableInt32WithSelector, - EnumerableMethod.AverageNullableInt64WithSelector, - EnumerableMethod.AverageNullableSingleWithSelector, - EnumerableMethod.AverageSingleWithSelector, - QueryableMethod.AverageDecimalWithSelector, - QueryableMethod.AverageDoubleWithSelector, - QueryableMethod.AverageInt32WithSelector, - QueryableMethod.AverageInt64WithSelector, - QueryableMethod.AverageNullableDecimalWithSelector, - QueryableMethod.AverageNullableDoubleWithSelector, - QueryableMethod.AverageNullableInt32WithSelector, - QueryableMethod.AverageNullableInt64WithSelector, - QueryableMethod.AverageNullableSingleWithSelector, - QueryableMethod.AverageSingleWithSelector - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__averageMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.AverageOverloads)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); AstExpression ast; - if (method.IsOneOf(__averageWithSelectorMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.AverageWithSelectorOverloads)) { var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); var selectorParameter = selectorLambda.Parameters[0]; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CompareMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CompareMethodToAggregationExpressionTranslator.cs index 79f68c311ce..91aedb65465 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CompareMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CompareMethodToAggregationExpressionTranslator.cs @@ -24,18 +24,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class CompareMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __stringCompareMethods = - [ - StringMethod.StaticCompare, - StringMethod.StaticCompareWithIgnoreCase - ]; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsStaticCompareMethod() || method.IsInstanceCompareToMethod() || method.IsOneOf(__stringCompareMethods)) + if (method.IsStaticCompareMethod() || method.IsInstanceCompareToMethod() || method.IsOneOf(StringMethod.CompareOverloads)) { Expression value1Expression; Expression value2Expression; @@ -54,7 +48,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var value2Translation = ExpressionToAggregationExpressionTranslator.Translate(context, value2Expression); AstExpression ast; - if (method.Is(StringMethod.StaticCompareWithIgnoreCase)) + if (method.Is(StringMethod.CompareWithIgnoreCase)) { var ignoreCaseExpression = arguments[2]; var ignoreCase = ignoreCaseExpression.GetConstantValue(containingExpression: expression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ContainsMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ContainsMethodToAggregationExpressionTranslator.cs index 7a4d64a3ff0..c6611ed9044 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ContainsMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ContainsMethodToAggregationExpressionTranslator.cs @@ -31,7 +31,8 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC return StartsWithContainsOrEndsWithMethodToAggregationExpressionTranslator.Translate(context, expression); } - if (IsEnumerableContainsMethod(expression, out var sourceExpression, out var valueExpression)) + if (EnumerableMethod.IsContainsMethod(expression, out var sourceExpression, out var valueExpression) && + !expression.Method.Is(StringMethod.ContainsWithChar)) { return TranslateEnumerableContains(context, expression, sourceExpression, valueExpression); } @@ -40,39 +41,6 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC } // private methods - private static bool IsEnumerableContainsMethod(MethodCallExpression expression, out Expression sourceExpression, out Expression valueExpression) - { - var method = expression.Method; - var arguments = expression.Arguments; - - if (method.IsOneOf(EnumerableMethod.Contains, QueryableMethod.Contains)) - { - sourceExpression = arguments[0]; - valueExpression = arguments[1]; - return true; - } - - if (!method.IsStatic && method.ReturnType == typeof(bool) && method.Name == "Contains" && arguments.Count == 1) - { - sourceExpression = expression.Object; - valueExpression = arguments[0]; - - if (sourceExpression.Type.TryGetIEnumerableGenericInterface(out var ienumerableInterface)) - { - var itemType = ienumerableInterface.GetGenericArguments()[0]; - if (itemType == valueExpression.Type) - { - // string.Contains(char) is not translated like other Contains methods because string is not represented as an array - return sourceExpression.Type != typeof(string) && valueExpression.Type != typeof(char); - } - } - } - - sourceExpression = null; - valueExpression = null; - return false; - } - private static TranslatedExpression TranslateEnumerableContains(TranslationContext context, MethodCallExpression expression, Expression sourceExpression, Expression valueExpression) { var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConvertMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConvertMethodToAggregationExpressionTranslator.cs index 9f6844b3031..d4283b67f23 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConvertMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConvertMethodToAggregationExpressionTranslator.cs @@ -42,8 +42,9 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var valueExpression = arguments[0]; var optionsExpression = arguments[1]; - var (toBsonType, toSerializer) = TranslateToType(expression, toType); + var toBsonType = GetResultRepresentation(expression, toType); var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); + var toSerializer = context.NodeSerializers.GetSerializer(expression); var (subType, byteOrder, format, onErrorAst, onNullAst) = TranslateOptions(context, expression, optionsExpression, toSerializer); var ast = AstExpression.Convert(valueTranslation.Ast, toBsonType.Render(), subType, byteOrder, format, onErrorAst, onNullAst); @@ -143,39 +144,39 @@ IBsonSerializer toSerializer return (subType, byteOrder, format, onErrorTranslation?.Ast, onNullTranslation?.Ast); } - private static (BsonType ToBsonType, IBsonSerializer ToSerializer) TranslateToType(Expression expression, Type toType) + private static BsonType GetResultRepresentation(Expression expression, Type toType) { var isNullable = toType.IsNullable(); var valueType = isNullable ? Nullable.GetUnderlyingType(toType) : toType; - var (bsonType, valueSerializer) = (ValueTuple)(Type.GetTypeCode(valueType) switch + var representation = Type.GetTypeCode(valueType) switch { - TypeCode.Boolean => (BsonType.Boolean, BooleanSerializer.Instance), - TypeCode.Byte => (BsonType.Int32, ByteSerializer.Instance), - TypeCode.Char => (BsonType.String, StringSerializer.Instance), - TypeCode.DateTime => (BsonType.DateTime, DateTimeSerializer.Instance), - TypeCode.Decimal => (BsonType.Decimal128, DecimalSerializer.Instance), - TypeCode.Double => (BsonType.Double, DoubleSerializer.Instance), - TypeCode.Int16 => (BsonType.Int32, Int16Serializer.Instance), - TypeCode.Int32 => (BsonType.Int32, Int32Serializer.Instance), - TypeCode.Int64 => (BsonType.Int64, Int64Serializer.Instance), - TypeCode.SByte => (BsonType.Int32, SByteSerializer.Instance), - TypeCode.Single => (BsonType.Double, SingleSerializer.Instance), - TypeCode.String => (BsonType.String, StringSerializer.Instance), - TypeCode.UInt16 => (BsonType.Int32, UInt16Serializer.Instance), - TypeCode.UInt32 => (BsonType.Int64, Int32Serializer.Instance), - TypeCode.UInt64 => (BsonType.Decimal128, UInt64Serializer.Instance), - - _ when valueType == typeof(byte[]) => (BsonType.Binary, ByteArraySerializer.Instance), - _ when valueType == typeof(BsonBinaryData) => (BsonType.Binary, BsonBinaryDataSerializer.Instance), - _ when valueType == typeof(Decimal128) => (BsonType.Decimal128, Decimal128Serializer.Instance), - _ when valueType == typeof(Guid) => (BsonType.Binary, GuidSerializer.StandardInstance), - _ when valueType == typeof(ObjectId) => (BsonType.ObjectId, ObjectIdSerializer.Instance), + TypeCode.Boolean => BsonType.Boolean, + TypeCode.Byte => BsonType.Int32, + TypeCode.Char => BsonType.String, + TypeCode.DateTime => BsonType.DateTime, + TypeCode.Decimal => BsonType.Decimal128, + TypeCode.Double => BsonType.Double, + TypeCode.Int16 => BsonType.Int32, + TypeCode.Int32 => BsonType.Int32, + TypeCode.Int64 => BsonType.Int64, + TypeCode.SByte => BsonType.Int32, + TypeCode.Single => BsonType.Double, + TypeCode.String => BsonType.String, + TypeCode.UInt16 => BsonType.Int32, + TypeCode.UInt32 => BsonType.Int64, + TypeCode.UInt64 => BsonType.Decimal128, + + _ when valueType == typeof(byte[]) => BsonType.Binary, + _ when valueType == typeof(BsonBinaryData) => BsonType.Binary, + _ when valueType == typeof(Decimal128) => BsonType.Decimal128, + _ when valueType == typeof(Guid) => BsonType.Binary, + _ when valueType == typeof(ObjectId) => BsonType.ObjectId, _ => throw new ExpressionNotSupportedException(expression, because: $"{toType} is not a valid TTo for Convert") - }); + }; - return (bsonType, isNullable ? NullableSerializer.Create(valueSerializer) : valueSerializer); + return representation; } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CountMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CountMethodToAggregationExpressionTranslator.cs index 73801af66f7..462a0655a44 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CountMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CountMethodToAggregationExpressionTranslator.cs @@ -27,45 +27,19 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class CountMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __countMethods; - private static readonly MethodInfo[] __countWithPredicateMethods; - - static CountMethodToAggregationExpressionTranslator() - { - __countMethods = new[] - { - EnumerableMethod.Count, - EnumerableMethod.CountWithPredicate, - EnumerableMethod.LongCount, - EnumerableMethod.LongCountWithPredicate, - QueryableMethod.Count, - QueryableMethod.CountWithPredicate, - QueryableMethod.LongCount, - QueryableMethod.LongCountWithPredicate - }; - - __countWithPredicateMethods = new[] - { - EnumerableMethod.CountWithPredicate, - EnumerableMethod.LongCountWithPredicate, - QueryableMethod.CountWithPredicate, - QueryableMethod.LongCountWithPredicate - }; - } - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__countMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.CountOverloads)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); AstExpression ast; - if (method.IsOneOf(__countWithPredicateMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.CountWithPredicateOverloads)) { if (sourceExpression.Type == typeof(string)) { diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DateFromStringMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DateFromStringMethodToAggregationExpressionTranslator.cs index 49e5c99a641..beacd741ab4 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DateFromStringMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DateFromStringMethodToAggregationExpressionTranslator.cs @@ -25,33 +25,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class DateFromStringMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __dateFromStringMethods = - { - MqlMethod.DateFromString, - MqlMethod.DateFromStringWithFormat, - MqlMethod.DateFromStringWithFormatAndTimezone, - MqlMethod.DateFromStringWithFormatAndTimezoneAndOnErrorAndOnNull - }; - - private static readonly MethodInfo[] __withFormatMethods = - { - MqlMethod.DateFromStringWithFormat, - MqlMethod.DateFromStringWithFormatAndTimezone, - MqlMethod.DateFromStringWithFormatAndTimezoneAndOnErrorAndOnNull - }; - - private static readonly MethodInfo[] __withTimezoneMethods = - { - MqlMethod.DateFromStringWithFormatAndTimezone, - MqlMethod.DateFromStringWithFormatAndTimezoneAndOnErrorAndOnNull - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__dateFromStringMethods)) + if (method.IsOneOf(MqlMethod.DateFromStringOverloads)) { var dateStringExpression = arguments[0]; var dateStringTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, dateStringExpression); @@ -59,7 +38,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC IBsonSerializer resultSerializer = DateTimeSerializer.Instance; AstExpression format = null; - if (method.IsOneOf(__withFormatMethods)) + if (method.IsOneOf(MqlMethod.DateFromStringWithFormatOverloads)) { var formatExpression = arguments[1]; var formatTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, formatExpression); @@ -67,7 +46,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC } AstExpression timezoneAst = null; - if (method.IsOneOf(__withTimezoneMethods)) + if (method.IsOneOf(MqlMethod.DateFromStringWithTimezoneOverloads)) { var timezoneExpression = arguments[2]; var timezoneTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, timezoneExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DateTimeAddOrSubtractMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DateTimeAddOrSubtractMethodToAggregationExpressionTranslator.cs index e3448e36864..993477d0dfb 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DateTimeAddOrSubtractMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DateTimeAddOrSubtractMethodToAggregationExpressionTranslator.cs @@ -29,81 +29,9 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class DateTimeAddOrSubtractMethodToAggregationExpressionTranslator { - private static MethodInfo[] __dateTimeAddOrSubtractMethods = new[] - { - DateTimeMethod.Add, - DateTimeMethod.AddDays, - DateTimeMethod.AddDaysWithTimezone, - DateTimeMethod.AddHours, - DateTimeMethod.AddHoursWithTimezone, - DateTimeMethod.AddMilliseconds, - DateTimeMethod.AddMillisecondsWithTimezone, - DateTimeMethod.AddMinutes, - DateTimeMethod.AddMinutesWithTimezone, - DateTimeMethod.AddMonths, - DateTimeMethod.AddMonthsWithTimezone, - DateTimeMethod.AddQuarters, - DateTimeMethod.AddQuartersWithTimezone, - DateTimeMethod.AddSeconds, - DateTimeMethod.AddSecondsWithTimezone, - DateTimeMethod.AddTicks, - DateTimeMethod.AddWeeks, - DateTimeMethod.AddWeeksWithTimezone, - DateTimeMethod.AddWithTimezone, - DateTimeMethod.AddWithUnit, - DateTimeMethod.AddWithUnitAndTimezone, - DateTimeMethod.AddYears, - DateTimeMethod.AddYearsWithTimezone, - DateTimeMethod.SubtractWithTimeSpan, - DateTimeMethod.SubtractWithTimeSpanAndTimezone, - DateTimeMethod.SubtractWithUnit, - DateTimeMethod.SubtractWithUnitAndTimezone - }; - - private static MethodInfo[] __dateTimeAddOrSubtractWithTimeSpanMethods = new[] - { - DateTimeMethod.Add, - DateTimeMethod.AddWithTimezone, - DateTimeMethod.SubtractWithTimeSpan, - DateTimeMethod.SubtractWithTimeSpanAndTimezone - }; - - private static MethodInfo[] __dateTimeAddOrSubtractWithUnitMethods = new[] - { - DateTimeMethod.AddWithUnit, - DateTimeMethod.AddWithUnitAndTimezone, - DateTimeMethod.SubtractWithUnit, - DateTimeMethod.SubtractWithUnitAndTimezone - }; - - private static MethodInfo[] __dateTimeAddOrSubtractWithTimezoneMethods = new[] - { - DateTimeMethod.AddDaysWithTimezone, - DateTimeMethod.AddHoursWithTimezone, - DateTimeMethod.AddMillisecondsWithTimezone, - DateTimeMethod.AddMinutesWithTimezone, - DateTimeMethod.AddMonthsWithTimezone, - DateTimeMethod.AddQuartersWithTimezone, - DateTimeMethod.AddSecondsWithTimezone, - DateTimeMethod.AddWeeksWithTimezone, - DateTimeMethod.AddWithTimezone, - DateTimeMethod.AddWithUnitAndTimezone, - DateTimeMethod.AddYearsWithTimezone, - DateTimeMethod.SubtractWithTimeSpanAndTimezone, - DateTimeMethod.SubtractWithUnitAndTimezone - }; - - private static MethodInfo[] __dateTimeSubtractMethods = new[] - { - DateTimeMethod.SubtractWithTimeSpan, - DateTimeMethod.SubtractWithTimeSpanAndTimezone, - DateTimeMethod.SubtractWithUnit, - DateTimeMethod.SubtractWithUnitAndTimezone - }; - public static bool CanTranslate(MethodCallExpression expression) { - return expression.Method.IsOneOf(__dateTimeAddOrSubtractMethods); + return expression.Method.IsOneOf(DateTimeMethod.AddOrSubtractOverloads); } public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) @@ -111,7 +39,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__dateTimeAddOrSubtractMethods)) + if (method.IsOneOf(DateTimeMethod.AddOrSubtractOverloads)) { Expression thisExpression, valueExpression; if (method.IsStatic) @@ -128,7 +56,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var thisTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, thisExpression); AstExpression unit, amount; - if (method.IsOneOf(__dateTimeAddOrSubtractWithTimeSpanMethods)) + if (method.IsOneOf(DateTimeMethod.AddOrSubtractWithTimeSpanOverloads)) { if (valueExpression is ConstantExpression constantValueExpression) { @@ -161,7 +89,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC }; } } - else if (method.IsOneOf(__dateTimeAddOrSubtractWithUnitMethods)) + else if (method.IsOneOf(DateTimeMethod.AddOrSubtractWithUnitOverloads)) { var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); var valueAst = ConvertHelper.RemoveWideningConvert(valueTranslation); @@ -192,14 +120,14 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC } AstExpression timezone = null; - if (method.IsOneOf(__dateTimeAddOrSubtractWithTimezoneMethods)) + if (method.IsOneOf(DateTimeMethod.AddOrSubtractWithTimezoneOverloads)) { var timezoneExpression = arguments.Last(); var timezoneTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, timezoneExpression); timezone = timezoneTranslation.Ast; } - var ast = method.IsOneOf(__dateTimeSubtractMethods) ? + var ast = method.IsOneOf(DateTimeMethod.SubtractReturningDateTimeOverloads) ? AstExpression.DateSubtract(thisTranslation.Ast, unit, amount, timezone) : AstExpression.DateAdd(thisTranslation.Ast, unit, amount, timezone); var serializer = DateTimeSerializer.UtcInstance; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DistinctMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DistinctMethodToAggregationExpressionTranslator.cs index 0486cca5e0b..7b573c4a4b1 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DistinctMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DistinctMethodToAggregationExpressionTranslator.cs @@ -24,18 +24,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class DistinctMethodToAggregationExpressionTranslator { - private readonly static MethodInfo[] __distinctMethods = - { - EnumerableMethod.Distinct, - QueryableMethod.Distinct - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__distinctMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.Distinct)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ElementAtMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ElementAtMethodToAggregationExpressionTranslator.cs index e6e9cf24e1d..d9bc8509335 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ElementAtMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ElementAtMethodToAggregationExpressionTranslator.cs @@ -23,26 +23,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class ElementAtMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __elementAtMethods = - { - EnumerableMethod.ElementAt, - EnumerableMethod.ElementAtOrDefault, - QueryableMethod.ElementAt, - QueryableMethod.ElementAtOrDefault - }; - - private static readonly MethodInfo[] __elementAtOrDefaultMethods = - { - EnumerableMethod.ElementAtOrDefault, - QueryableMethod.ElementAtOrDefault - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__elementAtMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.ElementAtOverloads)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); @@ -53,7 +39,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var indexTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, indexExpression); AstExpression ast; - if (method.IsOneOf(__elementAtOrDefaultMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.ElementAtOrDefault)) { var defaultValue = itemSerializer.ValueType.GetDefaultValue(); var serializedDefaultValue = SerializationHelper.SerializeValue(itemSerializer, defaultValue); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/EnumerableConcatMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/EnumerableConcatMethodToAggregationExpressionTranslator.cs index a42526b03ef..45f22d9c09a 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/EnumerableConcatMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/EnumerableConcatMethodToAggregationExpressionTranslator.cs @@ -24,21 +24,15 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class EnumerableConcatMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __concatMethods = - { - EnumerableMethod.Concat, - QueryableMethod.Concat - }; - public static bool CanTranslate(MethodCallExpression expression) - => expression.Method.IsOneOf(__concatMethods); + => expression.Method.IsOneOf(EnumerableOrQueryableMethod.Concat); public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__concatMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.Concat)) { var firstExpression = arguments[0]; var firstTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, firstExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ExceptMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ExceptMethodToAggregationExpressionTranslator.cs index a539d5e750d..ddc528f6d5c 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ExceptMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ExceptMethodToAggregationExpressionTranslator.cs @@ -24,18 +24,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class ExceptMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __exceptMethods = - { - EnumerableMethod.Except, - QueryableMethod.Except - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__exceptMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.Except)) { var firstExpression = arguments[0]; var firstTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, firstExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/FirstOrLastMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/FirstOrLastMethodToAggregationExpressionTranslator.cs index 96199df8207..811a2ab4346 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/FirstOrLastMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/FirstOrLastMethodToAggregationExpressionTranslator.cs @@ -25,68 +25,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class FirstOrLastMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __firstOrLastMethods = - { - EnumerableMethod.First, - EnumerableMethod.FirstWithPredicate, - EnumerableMethod.FirstOrDefault, - EnumerableMethod.FirstOrDefaultWithPredicate, - EnumerableMethod.Last, - EnumerableMethod.LastWithPredicate, - EnumerableMethod.LastOrDefault, - EnumerableMethod.LastOrDefaultWithPredicate, - QueryableMethod.First, - QueryableMethod.FirstWithPredicate, - QueryableMethod.FirstOrDefault, - QueryableMethod.FirstOrDefaultWithPredicate, - QueryableMethod.Last, - QueryableMethod.LastWithPredicate, - QueryableMethod.LastOrDefault, - QueryableMethod.LastOrDefaultWithPredicate - }; - - private static readonly MethodInfo[] __firstMethods = - { - EnumerableMethod.First, - EnumerableMethod.FirstWithPredicate, - EnumerableMethod.FirstOrDefault, - EnumerableMethod.FirstOrDefaultWithPredicate, - QueryableMethod.First, - QueryableMethod.FirstWithPredicate, - QueryableMethod.FirstOrDefault, - QueryableMethod.FirstOrDefaultWithPredicate - }; - - private static readonly MethodInfo[] __orDefaultMethods = - { - EnumerableMethod.FirstOrDefault, - EnumerableMethod.FirstOrDefaultWithPredicate, - EnumerableMethod.LastOrDefault, - EnumerableMethod.LastOrDefaultWithPredicate, - QueryableMethod.FirstOrDefault, - QueryableMethod.FirstOrDefaultWithPredicate, - QueryableMethod.LastOrDefault, - QueryableMethod.LastOrDefaultWithPredicate - }; - - private static readonly MethodInfo[] __withPredicateMethods = - { - EnumerableMethod.FirstWithPredicate, - EnumerableMethod.FirstOrDefaultWithPredicate, - EnumerableMethod.LastWithPredicate, - EnumerableMethod.LastOrDefaultWithPredicate, - QueryableMethod.FirstWithPredicate, - QueryableMethod.FirstOrDefaultWithPredicate, - QueryableMethod.LastWithPredicate, - QueryableMethod.LastOrDefaultWithPredicate - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__firstOrLastMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.FirstOrLastOverloads)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); @@ -95,9 +39,9 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var sourceAst = sourceTranslation.Ast; var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); - var isFirstMethod = method.IsOneOf(__firstMethods); + var isFirstMethod = method.IsOneOf(EnumerableOrQueryableMethod.FirstOverloads); - if (method.IsOneOf(__withPredicateMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.FirstOrLastWithPredicateOverloads)) { var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); var parameterExpression = predicateLambda.Parameters.Single(); @@ -122,7 +66,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC } AstExpression ast; - if (method.IsOneOf(__orDefaultMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.FirstOrDefaultOverloads, EnumerableOrQueryableMethod.LastOrDefaultOverloads)) { var defaultValue = itemSerializer.ValueType.GetDefaultValue(); var serializedDefaultValue = SerializationHelper.SerializeValue(itemSerializer, defaultValue); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IndexOfMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IndexOfMethodToAggregationExpressionTranslator.cs index 8fe5f57f89a..e6392e3c8af 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IndexOfMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IndexOfMethodToAggregationExpressionTranslator.cs @@ -27,56 +27,6 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class IndexOfMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __indexOfMethods = - { - StringMethod.IndexOfWithChar, - StringMethod.IndexOfBytesWithValue, - StringMethod.IndexOfBytesWithValueAndStartIndex, - StringMethod.IndexOfBytesWithValueAndStartIndexAndCount, - StringMethod.IndexOfWithCharAndStartIndex, - StringMethod.IndexOfWithCharAndStartIndexAndCount, - StringMethod.IndexOfWithString, - StringMethod.IndexOfWithStringAndStartIndex, - StringMethod.IndexOfWithStringAndStartIndexAndCount, - StringMethod.IndexOfWithStringAndComparisonType, - StringMethod.IndexOfWithStringAndStartIndexAndComparisonType, - StringMethod.IndexOfWithStringAndStartIndexAndCountAndComparisonType - }; - - private static readonly MethodInfo[] __indexOfWithStartIndexMethods = - { - StringMethod.IndexOfBytesWithValueAndStartIndex, - StringMethod.IndexOfBytesWithValueAndStartIndexAndCount, - StringMethod.IndexOfWithCharAndStartIndex, - StringMethod.IndexOfWithCharAndStartIndexAndCount, - StringMethod.IndexOfWithStringAndStartIndex, - StringMethod.IndexOfWithStringAndStartIndexAndCount, - StringMethod.IndexOfWithStringAndStartIndexAndComparisonType, - StringMethod.IndexOfWithStringAndStartIndexAndCountAndComparisonType - }; - - private static readonly MethodInfo[] __indexOfWithCountMethods = - { - StringMethod.IndexOfBytesWithValueAndStartIndexAndCount, - StringMethod.IndexOfWithCharAndStartIndexAndCount, - StringMethod.IndexOfWithStringAndStartIndexAndCount, - StringMethod.IndexOfWithStringAndStartIndexAndCountAndComparisonType - }; - - private static readonly MethodInfo[] __indexOfWithStringComparisonMethods = - { - StringMethod.IndexOfWithStringAndComparisonType, - StringMethod.IndexOfWithStringAndStartIndexAndComparisonType, - StringMethod.IndexOfWithStringAndStartIndexAndCountAndComparisonType - }; - - private static readonly MethodInfo[] __indexOfBytesMethods = - { - StringMethod.IndexOfBytesWithValue, - StringMethod.IndexOfBytesWithValueAndStartIndex, - StringMethod.IndexOfBytesWithValueAndStartIndexAndCount - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { if (IsStringIndexOfMethod(expression, out var objectExpression, out var valueExpression, out var startIndexExpression, out var countExpression, out var comparisonTypeExpression)) @@ -100,7 +50,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var endAst = CreateEndAst(startIndexTranslation?.Ast, countTranslation?.Ast); AstExpression ast; - if (expression.Method.IsOneOf(__indexOfBytesMethods) || ordinal) + if (expression.Method.IsOneOf(StringMethod.IndexOfBytesOverloads) || ordinal) { ast = AstExpression.IndexOfBytes(objectTranslation.Ast, valueTranslation.Ast, startIndexTranslation?.Ast, endAst); } @@ -167,14 +117,14 @@ private static bool IsStringIndexOfMethod( var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__indexOfMethods)) + if (method.IsOneOf(StringMethod.IndexOfOverloads)) { - if (method.IsOneOf(__indexOfBytesMethods)) + if (method.IsOneOf(StringMethod.IndexOfBytesOverloads)) { instanceExpression = arguments[0]; valueExpression = arguments[1]; - startIndexExpression = method.IsOneOf(__indexOfWithStartIndexMethods) ? arguments[2] : null; - countExpression = method.IsOneOf(__indexOfWithCountMethods) ? arguments[3] : null; + startIndexExpression = method.IsOneOf(StringMethod.IndexOfWithStartIndexOverloads) ? arguments[2] : null; + countExpression = method.IsOneOf(StringMethod.IndexOfWithCountOverloads) ? arguments[3] : null; comparisonTypeExpression = null; return true; } @@ -182,9 +132,9 @@ private static bool IsStringIndexOfMethod( { instanceExpression = expression.Object; valueExpression = arguments[0]; - startIndexExpression = method.IsOneOf(__indexOfWithStartIndexMethods) ? arguments[1] : null; - countExpression = method.IsOneOf(__indexOfWithCountMethods) ? arguments[2] : null; - comparisonTypeExpression = method.IsOneOf(__indexOfWithStringComparisonMethods) ? arguments.Last() : null; + startIndexExpression = method.IsOneOf(StringMethod.IndexOfWithStartIndexOverloads) ? arguments[1] : null; + countExpression = method.IsOneOf(StringMethod.IndexOfWithCountOverloads) ? arguments[2] : null; + comparisonTypeExpression = method.IsOneOf(StringMethod.IndexOfWithStringComparisonOverloads) ? arguments.Last() : null; return true; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IntersectMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IntersectMethodToAggregationExpressionTranslator.cs index c5519f5547d..82086889cda 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IntersectMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IntersectMethodToAggregationExpressionTranslator.cs @@ -27,7 +27,7 @@ internal static class IntersectMethodToAggregationExpressionTranslator private static readonly MethodInfo[] __intersectMethods = { EnumerableMethod.Intersect, - QueryableMethod.Interset + QueryableMethod.Intersect }; public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslator.cs index 0baa8709c1d..7572c2dd83f 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslator.cs @@ -24,50 +24,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal class MedianMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __medianMethods = - [ - MongoEnumerableMethod.MedianDecimal, - MongoEnumerableMethod.MedianDecimalWithSelector, - MongoEnumerableMethod.MedianDouble, - MongoEnumerableMethod.MedianDoubleWithSelector, - MongoEnumerableMethod.MedianInt32, - MongoEnumerableMethod.MedianInt32WithSelector, - MongoEnumerableMethod.MedianInt64, - MongoEnumerableMethod.MedianInt64WithSelector, - MongoEnumerableMethod.MedianNullableDecimal, - MongoEnumerableMethod.MedianNullableDecimalWithSelector, - MongoEnumerableMethod.MedianNullableDouble, - MongoEnumerableMethod.MedianNullableDoubleWithSelector, - MongoEnumerableMethod.MedianNullableInt32, - MongoEnumerableMethod.MedianNullableInt32WithSelector, - MongoEnumerableMethod.MedianNullableInt64, - MongoEnumerableMethod.MedianNullableInt64WithSelector, - MongoEnumerableMethod.MedianNullableSingle, - MongoEnumerableMethod.MedianNullableSingleWithSelector, - MongoEnumerableMethod.MedianSingle, - MongoEnumerableMethod.MedianSingleWithSelector - ]; - - private static readonly MethodInfo[] __medianWithSelectorMethods = - [ - MongoEnumerableMethod.MedianDecimalWithSelector, - MongoEnumerableMethod.MedianDoubleWithSelector, - MongoEnumerableMethod.MedianInt32WithSelector, - MongoEnumerableMethod.MedianInt64WithSelector, - MongoEnumerableMethod.MedianNullableDecimalWithSelector, - MongoEnumerableMethod.MedianNullableDoubleWithSelector, - MongoEnumerableMethod.MedianNullableInt32WithSelector, - MongoEnumerableMethod.MedianNullableInt64WithSelector, - MongoEnumerableMethod.MedianNullableSingleWithSelector, - MongoEnumerableMethod.MedianSingleWithSelector - ]; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__medianMethods)) + if (method.IsOneOf(MongoEnumerableMethod.MedianOverloads)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); @@ -75,7 +37,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var inputAst = sourceTranslation.Ast; - if (method.IsOneOf(__medianWithSelectorMethods)) + if (method.IsOneOf(MongoEnumerableMethod.MedianWithSelectorOverloads)) { var sourceItemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); @@ -104,4 +66,4 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC throw new ExpressionNotSupportedException(expression); } } -} \ No newline at end of file +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslator.cs index 216d89f1c49..005b5e7b80d 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslator.cs @@ -24,50 +24,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal class PercentileMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __percentileMethods = - [ - MongoEnumerableMethod.PercentileDecimal, - MongoEnumerableMethod.PercentileDecimalWithSelector, - MongoEnumerableMethod.PercentileDouble, - MongoEnumerableMethod.PercentileDoubleWithSelector, - MongoEnumerableMethod.PercentileInt32, - MongoEnumerableMethod.PercentileInt32WithSelector, - MongoEnumerableMethod.PercentileInt64, - MongoEnumerableMethod.PercentileInt64WithSelector, - MongoEnumerableMethod.PercentileNullableDecimal, - MongoEnumerableMethod.PercentileNullableDecimalWithSelector, - MongoEnumerableMethod.PercentileNullableDouble, - MongoEnumerableMethod.PercentileNullableDoubleWithSelector, - MongoEnumerableMethod.PercentileNullableInt32, - MongoEnumerableMethod.PercentileNullableInt32WithSelector, - MongoEnumerableMethod.PercentileNullableInt64, - MongoEnumerableMethod.PercentileNullableInt64WithSelector, - MongoEnumerableMethod.PercentileNullableSingle, - MongoEnumerableMethod.PercentileNullableSingleWithSelector, - MongoEnumerableMethod.PercentileSingle, - MongoEnumerableMethod.PercentileSingleWithSelector - ]; - - private static readonly MethodInfo[] __percentileWithSelectorMethods = - [ - MongoEnumerableMethod.PercentileDecimalWithSelector, - MongoEnumerableMethod.PercentileDoubleWithSelector, - MongoEnumerableMethod.PercentileInt32WithSelector, - MongoEnumerableMethod.PercentileInt64WithSelector, - MongoEnumerableMethod.PercentileNullableDecimalWithSelector, - MongoEnumerableMethod.PercentileNullableDoubleWithSelector, - MongoEnumerableMethod.PercentileNullableInt32WithSelector, - MongoEnumerableMethod.PercentileNullableInt64WithSelector, - MongoEnumerableMethod.PercentileNullableSingleWithSelector, - MongoEnumerableMethod.PercentileSingleWithSelector - ]; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__percentileMethods)) + if (method.IsOneOf(MongoEnumerableMethod.PercentileOverloads)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); @@ -75,7 +37,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var inputAst = sourceTranslation.Ast; - if (method.IsOneOf(__percentileWithSelectorMethods)) + if (method.IsOneOf(MongoEnumerableMethod.PercentileWithSelectorOverloads)) { var sourceItemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); @@ -107,4 +69,4 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC throw new ExpressionNotSupportedException(expression); } } -} \ No newline at end of file +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SelectManyMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SelectManyMethodToAggregationExpressionTranslator.cs index 89b67968c24..1a020ceeb74 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SelectManyMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SelectManyMethodToAggregationExpressionTranslator.cs @@ -25,18 +25,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class SelectManyMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __selectManyMethods = - { - EnumerableMethod.SelectMany, - QueryableMethod.SelectMany - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__selectManyMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.SelectManyWithSelector)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SetEqualsMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SetEqualsMethodToAggregationExpressionTranslator.cs index 3cacd4c462c..51679bddf22 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SetEqualsMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SetEqualsMethodToAggregationExpressionTranslator.cs @@ -17,6 +17,7 @@ using MongoDB.Bson.Serialization.Serializers; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Reflection; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators { @@ -24,8 +25,11 @@ internal static class SetEqualsMethodToAggregationExpressionTranslator { public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { - if (IsSetEqualsMethod(expression, out var objectExpression, out var otherExpression)) + if (ISetMethod.IsSetEqualsMethod(expression.Method)) { + var objectExpression = expression.Object; + var otherExpression = expression.Arguments[0]; + var objectTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, objectExpression); var otherTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, otherExpression); var ast = AstExpression.SetEquals(objectTranslation.Ast, otherTranslation.Ast); @@ -34,34 +38,5 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC } throw new ExpressionNotSupportedException(expression); } - - private static bool IsSetEqualsMethod(MethodCallExpression expression, out Expression objectExpression, out Expression otherExpression) - { - var method = expression.Method; - var arguments = expression.Arguments; - - if (!method.IsStatic && - method.ReturnType == typeof(bool) && - method.Name == "SetEquals" && - arguments.Count == 1) - { - objectExpression = expression.Object; - otherExpression = arguments[0]; - if (objectExpression.Type.TryGetIEnumerableGenericInterface(out var objectEnumerableInterface) && - otherExpression.Type.TryGetIEnumerableGenericInterface(out var otherEnumerableInterface)) - { - var objectItemType = objectEnumerableInterface.GetGenericArguments()[0]; - var otherItemType = otherEnumerableInterface.GetGenericArguments()[0]; - if (objectItemType == otherItemType) - { - return true; - } - } - } - - objectExpression = null; - otherExpression = null; - return false; - } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ToListMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ToListMethodToAggregationExpressionTranslator.cs index f95a2361fdc..263bd9ac6a8 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ToListMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ToListMethodToAggregationExpressionTranslator.cs @@ -20,6 +20,7 @@ using MongoDB.Bson.Serialization.Serializers; using MongoDB.Driver.Linq.Linq3Implementation.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators { @@ -37,10 +38,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); var listItemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); - var listItemType = listItemSerializer.ValueType; - var listType = typeof(List<>).MakeGenericType(listItemType); - var listSerializerType = typeof(EnumerableInterfaceImplementerSerializer<,>).MakeGenericType(listType, listItemType); - var listSerializer = (IBsonSerializer)Activator.CreateInstance(listSerializerType, listItemSerializer); + var listSerializer = ListSerializer.Create(listItemSerializer); return new TranslatedExpression(expression, sourceTranslation.Ast, listSerializer); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WhereMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WhereMethodToAggregationExpressionTranslator.cs index 250c8658210..644cc097334 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WhereMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WhereMethodToAggregationExpressionTranslator.cs @@ -70,7 +70,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC @as: predicateSymbol.Var.Name, limitTranslation?.Ast); - var resultSerializer = NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(expression.Type, itemSerializer); + var resultSerializer = context.NodeSerializers.GetSerializer(expression); return new TranslatedExpression(expression, ast, resultSerializer); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslator.cs index f45cffc3e49..d2272eb4876 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslator.cs @@ -29,128 +29,6 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class WindowMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __windowMethods = - { - WindowMethod.AddToSet, - WindowMethod.AverageWithDecimal, - WindowMethod.AverageWithDouble, - WindowMethod.AverageWithInt32, - WindowMethod.AverageWithInt64, - WindowMethod.AverageWithNullableDecimal, - WindowMethod.AverageWithNullableDouble, - WindowMethod.AverageWithNullableInt32, - WindowMethod.AverageWithNullableInt64, - WindowMethod.AverageWithNullableSingle, - WindowMethod.AverageWithSingle, - WindowMethod.Count, - WindowMethod.CovariancePopulationWithDecimals, - WindowMethod.CovariancePopulationWithDoubles, - WindowMethod.CovariancePopulationWithInt32s, - WindowMethod.CovariancePopulationWithInt64s, - WindowMethod.CovariancePopulationWithNullableDecimals, - WindowMethod.CovariancePopulationWithNullableDoubles, - WindowMethod.CovariancePopulationWithNullableInt32s, - WindowMethod.CovariancePopulationWithNullableInt64s, - WindowMethod.CovariancePopulationWithNullableSingles, - WindowMethod.CovariancePopulationWithSingles, - WindowMethod.CovarianceSampleWithDecimals, - WindowMethod.CovarianceSampleWithDoubles, - WindowMethod.CovarianceSampleWithInt32s, - WindowMethod.CovarianceSampleWithInt64s, - WindowMethod.CovarianceSampleWithNullableDecimals, - WindowMethod.CovarianceSampleWithNullableDoubles, - WindowMethod.CovarianceSampleWithNullableInt32s, - WindowMethod.CovarianceSampleWithNullableInt64s, - WindowMethod.CovarianceSampleWithNullableSingles, - WindowMethod.CovarianceSampleWithSingles, - WindowMethod.DenseRank, - WindowMethod.DerivativeWithDecimal, - WindowMethod.DerivativeWithDecimalAndUnit, - WindowMethod.DerivativeWithDouble, - WindowMethod.DerivativeWithDoubleAndUnit, - WindowMethod.DerivativeWithInt32, - WindowMethod.DerivativeWithInt32AndUnit, - WindowMethod.DerivativeWithInt64, - WindowMethod.DerivativeWithInt64AndUnit, - WindowMethod.DerivativeWithSingle, - WindowMethod.DerivativeWithSingleAndUnit, - WindowMethod.DocumentNumber, - WindowMethod.ExponentialMovingAverageWithDecimal, - WindowMethod.ExponentialMovingAverageWithDouble, - WindowMethod.ExponentialMovingAverageWithInt32, - WindowMethod.ExponentialMovingAverageWithInt64, - WindowMethod.ExponentialMovingAverageWithSingle, - WindowMethod.First, - WindowMethod.IntegralWithDecimal, - WindowMethod.IntegralWithDecimalAndUnit, - WindowMethod.IntegralWithDouble, - WindowMethod.IntegralWithDoubleAndUnit, - WindowMethod.IntegralWithInt32, - WindowMethod.IntegralWithInt32AndUnit, - WindowMethod.IntegralWithInt64, - WindowMethod.IntegralWithInt64AndUnit, - WindowMethod.IntegralWithSingle, - WindowMethod.IntegralWithSingleAndUnit, - WindowMethod.Last, - WindowMethod.Locf, - WindowMethod.Max, - WindowMethod.MedianWithDecimal, - WindowMethod.MedianWithDouble, - WindowMethod.MedianWithInt32, - WindowMethod.MedianWithInt64, - WindowMethod.MedianWithNullableDecimal, - WindowMethod.MedianWithNullableDouble, - WindowMethod.MedianWithNullableInt32, - WindowMethod.MedianWithNullableInt64, - WindowMethod.MedianWithNullableSingle, - WindowMethod.MedianWithSingle, - WindowMethod.Min, - WindowMethod.PercentileWithDecimal, - WindowMethod.PercentileWithDouble, - WindowMethod.PercentileWithInt32, - WindowMethod.PercentileWithInt64, - WindowMethod.PercentileWithNullableDecimal, - WindowMethod.PercentileWithNullableDouble, - WindowMethod.PercentileWithNullableInt32, - WindowMethod.PercentileWithNullableInt64, - WindowMethod.PercentileWithNullableSingle, - WindowMethod.PercentileWithSingle, - WindowMethod.Push, - WindowMethod.Rank, - WindowMethod.Shift, - WindowMethod.ShiftWithDefaultValue, - WindowMethod.StandardDeviationPopulationWithDecimal, - WindowMethod.StandardDeviationPopulationWithDouble, - WindowMethod.StandardDeviationPopulationWithInt32, - WindowMethod.StandardDeviationPopulationWithInt64, - WindowMethod.StandardDeviationPopulationWithNullableDecimal, - WindowMethod.StandardDeviationPopulationWithNullableDouble, - WindowMethod.StandardDeviationPopulationWithNullableInt32, - WindowMethod.StandardDeviationPopulationWithNullableInt64, - WindowMethod.StandardDeviationPopulationWithNullableSingle, - WindowMethod.StandardDeviationPopulationWithSingle, - WindowMethod.StandardDeviationSampleWithDecimal, - WindowMethod.StandardDeviationSampleWithDouble, - WindowMethod.StandardDeviationSampleWithInt32, - WindowMethod.StandardDeviationSampleWithInt64, - WindowMethod.StandardDeviationSampleWithNullableDecimal, - WindowMethod.StandardDeviationSampleWithNullableDouble, - WindowMethod.StandardDeviationSampleWithNullableInt32, - WindowMethod.StandardDeviationSampleWithNullableInt64, - WindowMethod.StandardDeviationSampleWithNullableSingle, - WindowMethod.StandardDeviationSampleWithSingle, - WindowMethod.SumWithDecimal, - WindowMethod.SumWithDouble, - WindowMethod.SumWithInt32, - WindowMethod.SumWithInt64, - WindowMethod.SumWithNullableDecimal, - WindowMethod.SumWithNullableDouble, - WindowMethod.SumWithNullableInt32, - WindowMethod.SumWithNullableInt64, - WindowMethod.SumWithNullableSingle, - WindowMethod.SumWithSingle - }; - private static readonly MethodInfo[] __nullaryMethods = { WindowMethod.Count, @@ -299,7 +177,7 @@ internal static class WindowMethodToAggregationExpressionTranslator public static bool CanTranslate(MethodCallExpression expression) { - return expression.Method.IsOneOf(__windowMethods); + return IsWindowMethod(expression.Method); } public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) @@ -308,7 +186,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var parameters = method.GetParameters(); var arguments = expression.Arguments.ToArray(); - if (method.IsOneOf(__windowMethods)) + if (IsWindowMethod(method)) { var partitionExpression = arguments[0]; var partitionTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, partitionExpression); @@ -623,5 +501,10 @@ private static IBsonSerializer GetSortBySerializerGeneric( return renderedField.FieldSerializer; } + + private static bool IsWindowMethod(MethodInfo method) + { + return method.DeclaringType == typeof(ISetWindowFieldsPartitionExtensions); + } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs index c5eba340536..fedb16e0893 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs @@ -19,6 +19,8 @@ using MongoDB.Bson.Serialization; using MongoDB.Bson.Serialization.Serializers; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { @@ -27,28 +29,14 @@ internal static class NewArrayInitExpressionToAggregationExpressionTranslator public static TranslatedExpression Translate(TranslationContext context, NewArrayExpression expression) { var items = new List(); - IBsonSerializer itemSerializer = null; foreach (var itemExpression in expression.Expressions) { var itemTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, itemExpression); items.Add(itemTranslation.Ast); - itemSerializer ??= itemTranslation.Serializer; - - // make sure all items are serialized using the same serializer - if (!itemTranslation.Serializer.Equals(itemSerializer)) - { - throw new ExpressionNotSupportedException(expression, because: "all items in the array must be serialized using the same serializer"); - } } - var ast = AstExpression.ComputedArray(items); - var arrayType = expression.Type; - var itemType = arrayType.GetElementType(); - itemSerializer ??= BsonSerializer.LookupSerializer(itemType); // if the array is empty itemSerializer will be null - var arraySerializerType = typeof(ArraySerializer<>).MakeGenericType(itemType); - var arraySerializer = (IBsonSerializer)Activator.CreateInstance(arraySerializerType, itemSerializer); - + var arraySerializer = context.NodeSerializers.GetSerializer(expression); return new TranslatedExpression(expression, ast, arraySerializer); } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs index aee174ac38d..af7b324c2f3 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs @@ -39,34 +39,21 @@ public static TranslatedExpression Translate(TranslationContext context, NewExpr var collectionTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, collectionExpression); var itemSerializer = ArraySerializerHelper.GetItemSerializer(collectionTranslation.Serializer); - IBsonSerializer keySerializer; - IBsonSerializer valueSerializer; AstExpression collectionTranslationAst; - if (itemSerializer is IBsonDocumentSerializer itemDocumentSerializer) + if (itemSerializer.IsKeyValuePairSerializer(out var keyElementName, out var valueElementName, out var keySerializer, out var valueSerializer)) { - if (!itemDocumentSerializer.TryGetMemberSerializationInfo("Key", out var keyMemberSerializationInfo)) - { - throw new ExpressionNotSupportedException(expression, because: $"serializer class {itemSerializer.GetType()} does not have a Key member"); - } - keySerializer = keyMemberSerializationInfo.Serializer; - - if (!itemDocumentSerializer.TryGetMemberSerializationInfo("Value", out var valueMemberSerializationInfo)) - { - throw new ExpressionNotSupportedException(expression, because: $"serializer class {itemSerializer.GetType()} does not have a Value member"); - } - valueSerializer = valueMemberSerializationInfo.Serializer; - - if (keyMemberSerializationInfo.ElementName == "k" && valueMemberSerializationInfo.ElementName == "v") + if (keyElementName == "k" && valueElementName == "v") { collectionTranslationAst = collectionTranslation.Ast; } else { + // map keyElementName and valueElementName to "k" and "v" var pairVar = AstExpression.Var("pair"); var computedDocumentAst = AstExpression.ComputedDocument([ - AstExpression.ComputedField("k", AstExpression.GetField(pairVar, keyMemberSerializationInfo.ElementName)), - AstExpression.ComputedField("v", AstExpression.GetField(pairVar, valueMemberSerializationInfo.ElementName)) + AstExpression.ComputedField("k", AstExpression.GetField(pairVar, keyElementName)), + AstExpression.ComputedField("v", AstExpression.GetField(pairVar, valueElementName)) ]); collectionTranslationAst = AstExpression.Map(collectionTranslation.Ast, pairVar, computedDocumentAst); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewKeyValuePairExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewKeyValuePairExpressionToAggregationExpressionTranslator.cs index cfe4f67f6a8..5eaf71255b7 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewKeyValuePairExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewKeyValuePairExpressionToAggregationExpressionTranslator.cs @@ -13,13 +13,12 @@ * limitations under the License. */ -using System; using System.Collections.Generic; using System.Linq.Expressions; using MongoDB.Bson; -using MongoDB.Bson.Serialization; using MongoDB.Bson.Serialization.Serializers; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewListExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewListExpressionToAggregationExpressionTranslator.cs index 3063460c00b..68acaa92950 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewListExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewListExpressionToAggregationExpressionTranslator.cs @@ -34,7 +34,7 @@ public static TranslatedExpression Translate(TranslationContext context, NewExpr { var argument = arguments[0]; var argumentType = argument.Type; - if (argumentType.IsConstructedGenericType && argumentType.GetGenericTypeDefinition().Implements(typeof(IEnumerable<>))) + if (argumentType.IsConstructedGenericType && argumentType.GetGenericTypeDefinition().ImplementsInterface(typeof(IEnumerable<>))) { var enumerableInterface = argumentType.GetIEnumerableGenericInterface(); var argumentItemType = enumerableInterface.GetGenericArguments()[0]; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewTupleExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewTupleExpressionToAggregationExpressionTranslator.cs index 66235aa71ca..a94d7fc3817 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewTupleExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewTupleExpressionToAggregationExpressionTranslator.cs @@ -13,13 +13,11 @@ * limitations under the License. */ -using System; -using System.Collections.Generic; using System.Linq.Expressions; using MongoDB.Bson.Serialization; -using MongoDB.Bson.Serialization.Serializers; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { @@ -49,16 +47,11 @@ public static TranslatedExpression Translate(TranslationContext context, NewExpr } var ast = AstExpression.ComputedArray(items); - var tupleSerializer = CreateTupleSerializer(tupleType, itemSerializers); + var tupleSerializer = TupleOrValueTupleSerializer.Create(tupleType, itemSerializers); return new TranslatedExpression(expression, ast, tupleSerializer); } throw new ExpressionNotSupportedException(expression); } - - private static IBsonSerializer CreateTupleSerializer(Type tupleType, IEnumerable itemSerializers) - { - return tupleType.IsTuple() ? TupleSerializer.Create(itemSerializers) : ValueTupleSerializer.Create(itemSerializers); - } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NotExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NotExpressionToAggregationExpressionTranslator.cs index 692b3600ddd..486ae382721 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NotExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NotExpressionToAggregationExpressionTranslator.cs @@ -24,6 +24,7 @@ public static TranslatedExpression Translate(TranslationContext context, UnaryEx { if (expression.NodeType == ExpressionType.Not) { + // TODO: check operand representation var operandTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, expression.Operand); var ast = expression.Type == typeof(bool) ? AstExpression.Not(operandTranslation.Ast) : AstExpression.BitNot(operandTranslation.Ast); return new TranslatedExpression(expression, ast, operandTranslation.Serializer); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs index a6a89b7639f..00a7780eb66 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs @@ -13,6 +13,8 @@ * limitations under the License. */ +using System; +using System.Linq; using System.Linq.Expressions; using System.Threading; using System.Threading.Tasks; @@ -31,7 +33,7 @@ public static ExecutableQuery> Translate TranslateScalar(containingExpression: compareMethodCallExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/AllWithContainsInPredicateMethodToFilterTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/AllWithContainsInPredicateMethodToFilterTranslator.cs index cabda2421f0..3ae6b57d249 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/AllWithContainsInPredicateMethodToFilterTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/AllWithContainsInPredicateMethodToFilterTranslator.cs @@ -64,7 +64,7 @@ public static AstFilter Translate(TranslationContext context, Expression arrayFi private static bool IsContainsParameterExpression(Expression predicateBody, ParameterExpression predicateParameter, out Expression innerSourceExpression) { if (predicateBody is MethodCallExpression methodCallExpression && - IsContainsMethodCall(methodCallExpression, out var sourceExpression, out var valueExpression) && + EnumerableMethod.IsContainsMethod(methodCallExpression, out var sourceExpression, out var valueExpression) && valueExpression == predicateParameter) { innerSourceExpression = sourceExpression; @@ -73,49 +73,6 @@ private static bool IsContainsParameterExpression(Expression predicateBody, Para innerSourceExpression = null; return false; - - static bool IsContainsMethodCall(MethodCallExpression methodCallExpression, out Expression sourceExpression, out Expression valueExpression) - { - var method = methodCallExpression.Method; - var arguments = methodCallExpression.Arguments; - - if (method.Name == "Contains" && method.ReturnType == typeof(bool)) - { - if (method.IsStatic && arguments.Count == 2) - { - sourceExpression = arguments[0]; - valueExpression = arguments[1]; - if (ValueTypeIsElementTypeOfSourceType(valueExpression, sourceExpression)) - { - return true; - } - } - else if (!method.IsStatic && arguments.Count == 1) - { - sourceExpression = methodCallExpression.Object; - valueExpression = arguments[0]; - if (ValueTypeIsElementTypeOfSourceType(valueExpression, sourceExpression)) - { - return true; - } - } - } - - sourceExpression = null; - valueExpression = null; - return false; - } - - static bool ValueTypeIsElementTypeOfSourceType(Expression valueExpression, Expression sourceExpression) - { - if (sourceExpression.Type.TryGetIEnumerableGenericInterface(out var ienumerableInterface)) - { - var elementType = ienumerableInterface.GetGenericArguments()[0]; - return elementType.IsAssignableFrom(valueExpression.Type); - } - - return false; - } } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/ContainsMethodToFilterTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/ContainsMethodToFilterTranslator.cs index 574a93809c6..3dd10aec451 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/ContainsMethodToFilterTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/ContainsMethodToFilterTranslator.cs @@ -20,6 +20,7 @@ using MongoDB.Driver.Linq.Linq3Implementation.Ast.Filters; using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods; using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Reflection; using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToFilterTranslators.ToFilterFieldTranslators; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToFilterTranslators.MethodTranslators @@ -36,30 +37,11 @@ public static AstFilter Translate(TranslationContext context, MethodCallExpressi var method = expression.Method; var arguments = expression.Arguments; - if (method.IsStatic && - method.Name == "Contains" && - method.ReturnType == typeof(bool) && - arguments.Count == 2) + if (EnumerableMethod.IsContainsMethod(expression, out var fieldExpression, out var itemExpression)) { - var fieldExpression = arguments[0]; var fieldType = fieldExpression.Type; - var itemExpression = arguments[1]; var itemType = itemExpression.Type; - if (TypeImplementsIEnumerable(fieldType, itemType)) - { - return Translate(context, expression, fieldExpression, itemExpression); - } - } - if (!method.IsStatic && - method.Name == "Contains" && - method.ReturnType == typeof(bool) && - arguments.Count == 1) - { - var fieldExpression = expression.Object; - var fieldType = fieldExpression.Type; - var itemExpression = arguments[0]; - var itemType = itemExpression.Type; if (TypeImplementsIEnumerable(fieldType, itemType)) { return Translate(context, expression, fieldExpression, itemExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ToFilterFieldTranslators/ConvertExpressionToFilterFieldTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ToFilterFieldTranslators/ConvertExpressionToFilterFieldTranslator.cs index 7a734d7a075..86f0ce6372f 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ToFilterFieldTranslators/ConvertExpressionToFilterFieldTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ToFilterFieldTranslators/ConvertExpressionToFilterFieldTranslator.cs @@ -148,7 +148,7 @@ private static TranslatedFilterField TranslateConvertEnumToUnderlyingType(Transl enumSerializer = fieldSerializer; } - var targetSerializer = EnumUnderlyingTypeSerializer.Create(enumSerializer); + var targetSerializer = AsEnumUnderlyingTypeSerializer.Create(enumSerializer); if (targetType.IsNullable()) { targetSerializer = NullableSerializer.Create(targetSerializer); @@ -186,7 +186,7 @@ private static TranslatedFilterField TranslateConvertUnderlyingTypeToEnum(Transl } IBsonSerializer targetSerializer; - if (valueSerializer is IEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer) + if (valueSerializer is IAsEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer) { targetSerializer = enumUnderlyingTypeSerializer.EnumSerializer; } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/ConcatMethodToPipelineTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/ConcatMethodToPipelineTranslator.cs index 03fb1ecb1b7..5512bc1e44b 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/ConcatMethodToPipelineTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/ConcatMethodToPipelineTranslator.cs @@ -20,6 +20,7 @@ using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods; using MongoDB.Driver.Linq.Linq3Implementation.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToPipelineTranslators { @@ -44,7 +45,7 @@ secondProvider.CollectionNamespace is var secondCollectionNamespace && secondCollectionNamespace != null) { var secondCollectionName = secondCollectionNamespace.CollectionName; - var secondContext = TranslationContext.Create(context.TranslationOptions); + var secondContext = TranslationContext.Create(secondQueryable, context.TranslationOptions); var secondPipeline = ExpressionToPipelineTranslator.Translate(secondContext, secondQueryable.Expression); if (secondPipeline.Ast.Stages.Count == 0) { diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/LookupMethodToPipelineTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/LookupMethodToPipelineTranslator.cs index 9bc144c8875..ba43dcebfdd 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/LookupMethodToPipelineTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/LookupMethodToPipelineTranslator.cs @@ -281,7 +281,13 @@ private static TranslatedPipeline TranslateDocumentsPipelineGeneric= 1) + { + var sourceParameter = parameters[0]; + var sourceParameterType = sourceParameter.ParameterType; + if (sourceParameterType.IsConstructedGenericType) + { + sourceParameterType = sourceParameterType.GetGenericTypeDefinition(); + } + + if (sourceParameterType == typeof(IQueryable) || + sourceParameterType == typeof(IQueryable<>) || + sourceParameterType == typeof(IOrderedQueryable) || + sourceParameterType == typeof(IOrderedQueryable<>)) + { + return GetUltimateSource(methodCallExpression.Arguments[0]); + } + } + + throw new ArgumentException($"No ultimate source found: {expression}."); } #endregion // private fields private readonly TranslationContextData _data; + private readonly IReadOnlySerializerMap _nodeSerializers; private readonly NameGenerator _nameGenerator; private readonly SymbolTable _symbolTable; private readonly ExpressionTranslationOptions _translationOptions; private TranslationContext( ExpressionTranslationOptions translationOptions, + IReadOnlySerializerMap nodeSerializers, TranslationContextData data, SymbolTable symbolTable, NameGenerator nameGenerator) { _translationOptions = translationOptions ?? new ExpressionTranslationOptions(); + _nodeSerializers = Ensure.IsNotNull(nodeSerializers, nameof(nodeSerializers)); _data = data; // can be null _symbolTable = Ensure.IsNotNull(symbolTable, nameof(symbolTable)); _nameGenerator = Ensure.IsNotNull(nameGenerator, nameof(nameGenerator)); @@ -54,6 +146,7 @@ private TranslationContext( // public properties public TranslationContextData Data => _data; + public IReadOnlySerializerMap NodeSerializers => _nodeSerializers; public NameGenerator NameGenerator => _nameGenerator; public SymbolTable SymbolTable => _symbolTable; public ExpressionTranslationOptions TranslationOptions => _translationOptions; @@ -99,6 +192,11 @@ public Symbol CreateSymbolWithVarName(ParameterExpression parameter, string varN return CreateSymbol(parameter, name: parameterName, varName, serializer, isCurrent); } + public IBsonSerializer GetSerializer(Expression parameter) + { + return _nodeSerializers.GetSerializer(parameter); + } + public override string ToString() { return $"{{ SymbolTable : {_symbolTable} }}"; @@ -124,7 +222,7 @@ public TranslationContext WithSymbols(params Symbol[] newSymbols) public TranslationContext WithSymbolTable(SymbolTable symbolTable) { - return new TranslationContext(_translationOptions, _data, symbolTable, _nameGenerator); + return new TranslationContext(_translationOptions, _nodeSerializers, _data, symbolTable, _nameGenerator); } } } diff --git a/src/MongoDB.Driver/Linq/LinqProviderAdapter.cs b/src/MongoDB.Driver/Linq/LinqProviderAdapter.cs index 4da85f6bd6d..57ce448f981 100644 --- a/src/MongoDB.Driver/Linq/LinqProviderAdapter.cs +++ b/src/MongoDB.Driver/Linq/LinqProviderAdapter.cs @@ -14,17 +14,16 @@ */ using System; -using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; using MongoDB.Bson; using MongoDB.Bson.Serialization; -using MongoDB.Driver; using MongoDB.Driver.Core.Misc; using MongoDB.Driver.Linq.Linq3Implementation; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Optimizers; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Stages; using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; using MongoDB.Driver.Linq.Linq3Implementation.Translators; using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators; using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToFilterTranslators; @@ -61,7 +60,8 @@ internal static BsonValue TranslateExpressionToAggregateExpression>)LinqExpressionPreprocessor.Preprocess(expression); - var context = TranslationContext.Create(translationOptions, contextData); + var parameter = expression.Parameters.Single(); + var context = TranslationContext.Create(expression, initialNode: parameter, initialSerializer: sourceSerializer, translationOptions: translationOptions, data: contextData); var translation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, expression, sourceSerializer, asRoot: true); var simplifiedAst = AstSimplifier.Simplify(translation.Ast); @@ -76,7 +76,7 @@ internal static RenderedFieldDefinition TranslateExpressionToField( { expression = (LambdaExpression)LinqExpressionPreprocessor.Preprocess(expression); var parameter = expression.Parameters.Single(); - var context = TranslationContext.Create(translationOptions); + var context = TranslationContext.Create(expression, initialNode: parameter, initialSerializer: documentSerializer, translationOptions: translationOptions); var symbol = context.CreateSymbol(parameter, documentSerializer, isCurrent: true); context = context.WithSymbol(symbol); var body = RemovePossibleConvertToObject(expression.Body); @@ -106,7 +106,7 @@ internal static RenderedFieldDefinition TranslateExpressionToField>)LinqExpressionPreprocessor.Preprocess(expression); var parameter = expression.Parameters.Single(); - var context = TranslationContext.Create(translationOptions); + var context = TranslationContext.Create(expression, initialNode: parameter, initialSerializer: documentSerializer, translationOptions: translationOptions); var symbol = context.CreateSymbol(parameter, documentSerializer, isCurrent: true); context = context.WithSymbol(symbol); var fieldTranslation = ExpressionToFilterFieldTranslator.Translate(context, expression.Body); @@ -125,8 +125,8 @@ internal static BsonDocument TranslateExpressionToElemMatchFilter( ExpressionTranslationOptions translationOptions) { expression = (Expression>)LinqExpressionPreprocessor.Preprocess(expression); - var context = TranslationContext.Create(translationOptions); var parameter = expression.Parameters.Single(); + var context = TranslationContext.Create(expression, initialNode: parameter, initialSerializer: elementSerializer, translationOptions: translationOptions); var symbol = context.CreateSymbol(parameter, "@", elementSerializer); // @ represents the implied element context = context.WithSingleSymbol(symbol); // @ is the only symbol visible inside an $elemMatch var filter = ExpressionToFilterTranslator.Translate(context, expression.Body, exprOk: false); @@ -142,7 +142,8 @@ internal static BsonDocument TranslateExpressionToFilter( ExpressionTranslationOptions translationOptions) { expression = (Expression>)LinqExpressionPreprocessor.Preprocess(expression); - var context = TranslationContext.Create(translationOptions); + var parameter = expression.Parameters.Single(); + var context = TranslationContext.Create(expression, initialNode: parameter, initialSerializer: documentSerializer, translationOptions: translationOptions); var filter = ExpressionToFilterTranslator.TranslateLambda(context, expression, documentSerializer, asRoot: true); filter = AstSimplifier.SimplifyAndConvert(filter); @@ -176,7 +177,8 @@ private static RenderedProjectionDefinition TranslateExpressionToProjec } expression = (Expression>)LinqExpressionPreprocessor.Preprocess(expression); - var context = TranslationContext.Create(translationOptions); + var parameter = expression.Parameters.Single(); + var context = TranslationContext.Create(expression, initialNode: parameter, initialSerializer: inputSerializer, translationOptions: translationOptions); var simplifier = forFind ? new AstFindProjectionSimplifier() : new AstSimplifier(); try @@ -215,8 +217,18 @@ internal static BsonDocument TranslateExpressionToSetStage( IBsonSerializerRegistry serializerRegistry, ExpressionTranslationOptions translationOptions) { - var context = TranslationContext.Create(translationOptions); // do not partially evaluate expression var parameter = expression.Parameters.Single(); + var body = expression.Body; + + var nodeSerializers = new SerializerMap(); + nodeSerializers.AddSerializer(parameter, documentSerializer); + if (body.Type == typeof(TDocument)) + { + nodeSerializers.AddSerializer(body, documentSerializer); + } + SerializerFinder.FindSerializers(expression, translationOptions, nodeSerializers); + + var context = TranslationContext.Create(translationOptions, nodeSerializers); // do not partially evaluate expression var symbol = context.CreateRootSymbol(parameter, documentSerializer); context = context.WithSymbol(symbol); var setStage = ExpressionToSetStageTranslator.Translate(context, documentSerializer, expression); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp2472Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp2472Tests.cs index 035bba42f7e..fe4117ce430 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp2472Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp2472Tests.cs @@ -17,6 +17,8 @@ using System.Collections.Generic; using System.Linq; using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; using MongoDB.Driver.Core.Misc; using MongoDB.Driver.TestHelpers; using Xunit; @@ -79,7 +81,7 @@ public class C private class MyDTO { public DateTime timestamp { get; set; } - public decimal sqrt_calc { get; set; } + [BsonRepresentation(BsonType.Decimal128)] public decimal sqrt_calc { get; set; } } public sealed class ClassFixture : MongoCollectionFixture diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4054Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4054Tests.cs index 524b72ff602..5177538f54c 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4054Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4054Tests.cs @@ -43,6 +43,17 @@ from movieId in person.MovieIds join movie in movies.AsQueryable() on movieId equals movie.Id select new { person, movie }; + // equivalement method call syntax + // var queryable = people.AsQueryable() + // .SelectMany( + // person => person.MovieIds, + // (person, movieId) => new { person = person, movieId = movieId }) + // .Join( + // movies.AsQueryable(), + // transparentIdentifier => transparentIdentifier.movieId, + // movie => movie.Id, + // (transparentIdentifier, movie) => new { person = transparentIdentifier.person, movie = movie }); + var stages = Translate(people, queryable); AssertStages( stages, diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4593Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4593Tests.cs new file mode 100644 index 00000000000..0224126f84d --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4593Tests.cs @@ -0,0 +1,144 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using FluentAssertions; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp4593Tests : LinqIntegrationTest +{ + public CSharp4593Tests(ClassFixture fixture) + : base(fixture) + { + } + + [Fact] + public void First_example_should_work() + { + var collection = Fixture.Orders; + + var find = collection + .Find(o => o.RateBasisHistoryId == "abc") + .Project(r => r.Id); + + var translatedFilter = TranslateFindFilter(collection, find); + translatedFilter.Should().Be("{ RateBasisHistoryId : 'abc' }"); + + var translatedProjection = TranslateFindProjection(collection, find); + translatedProjection.Should().Be("{ _id : 1 }"); + + var result = find.Single(); + result.Should().Be("a"); + } + + [Fact] + public void First_example_workaround_should_work() + { + var collection = Fixture.Orders; + + var find = collection + .Find(o => o.RateBasisHistoryId == "abc") + .Project(Builders.Projection.Include(o => o.Id)); + + var translatedFilter = TranslateFindFilter(collection, find); + translatedFilter.Should().Be("{ RateBasisHistoryId : 'abc' }"); + + var translatedProjection = TranslateFindProjection(collection, find); + translatedProjection.Should().Be("{ _id : 1 }"); + + var result = find.Single(); + result["_id"].AsString.Should().Be("a"); + } + + [Fact] + public void Second_example_should_work() + { + var collection = Fixture.Entities; + var idsFilter = Builders.Filter.Eq(x => x.Id, 1); + + var aggregate = collection.Aggregate() + .Match(idsFilter) + .Project(e => new + { + _id = e.Id, + CampaignId = e.CampaignId, + Accepted = e.Status.Key == "Accepted" ? 1 : 0, + Rejected = e.Status.Key == "Rejected" ? 1 : 0, + }); + + var stages = Translate(collection, aggregate); + AssertStages( + stages, + "{ $match : { _id : 1 } }", + """ + { $project : + { + _id : "$_id", + CampaignId : "$CampaignId", + Accepted : { $cond : { if : { $eq : ["$Status.Key", "Accepted"] }, then : 1, else : 0 } }, + Rejected : { $cond : { if : { $eq : ["$Status.Key", "Rejected"] }, then : 1, else : 0 } } + } + } + """); + + var results = aggregate.ToList(); + results.Count.Should().Be(1); + results[0]._id.Should().Be(1); + results[0].CampaignId.Should().Be(11); + results[0].Accepted.Should().Be(1); + results[0].Rejected.Should().Be(0); + } + + public class Order + { + public string Id { get; set; } + public string RateBasisHistoryId { get; set; } + } + + public class Entity + { + public int Id { get; set; } + public int CampaignId { get; set; } + public Status Status { get; set; } + } + + public class Status + { + public string Key { get; set; } + } + + public sealed class ClassFixture : MongoDatabaseFixture + { + public IMongoCollection Orders { get; private set; } + public IMongoCollection Entities { get; private set; } + + protected override void InitializeFixture() + { + Orders = CreateCollection("orders"); + Orders.InsertMany( + [ + new Order { Id = "a", RateBasisHistoryId = "abc" } + ]); + + Entities = CreateCollection("entities"); + Entities.InsertMany( + [ + new Entity { Id = 1, CampaignId = 11, Status = new Status { Key = "Accepted" } }, + new Entity { Id = 2, CampaignId = 22, Status = new Status { Key = "Rejected" } } + ]); + } + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4708Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4708Tests.cs index 2164f38e6a0..4225bef829f 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4708Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4708Tests.cs @@ -355,7 +355,7 @@ public void Where_Document_item_with_int_using_call_to_get_item_should_work() Expression.Property(x, typeof(C).GetProperty("Document")), typeof(BsonDocument).GetProperty("Item", new[] { typeof(int) }).GetGetMethod(), Expression.Constant(0)), - Expression.Constant(BsonValue.Create(1))); + Expression.Constant(BsonValue.Create(1), typeof(BsonValue))); var parameters = new ParameterExpression[] { x }; var predicate = Expression.Lambda>(body, parameters); @@ -379,7 +379,7 @@ public void Where_Document_item_with_int_using_MakeIndex_should_work() Expression.Property(x, typeof(C).GetProperty("Document")), typeof(BsonDocument).GetProperty("Item", new[] { typeof(int) }), new Expression[] { Expression.Constant(0) }), - Expression.Constant(BsonValue.Create(1))); + Expression.Constant(BsonValue.Create(1), typeof(BsonValue))); var parameters = new ParameterExpression[] { x }; var predicate = Expression.Lambda>(body, parameters); @@ -418,7 +418,7 @@ public void Where_Document_item_with_string_using_call_to_get_item_should_work() Expression.Property(x, typeof(C).GetProperty("Document")), typeof(BsonDocument).GetProperty("Item", new[] { typeof(string) }).GetGetMethod(), Expression.Constant("a")), - Expression.Constant(BsonValue.Create(1))); + Expression.Constant(BsonValue.Create(1), typeof(BsonValue))); var parameters = new ParameterExpression[] { x }; var predicate = Expression.Lambda>(body, parameters); @@ -442,7 +442,7 @@ public void Where_Document_item_with_string_using_MakeIndex_should_work() Expression.Property(x, typeof(C).GetProperty("Document")), typeof(BsonDocument).GetProperty("Item", new[] { typeof(string) }), new Expression[] { Expression.Constant("a") }), - Expression.Constant(BsonValue.Create(1))); + Expression.Constant(BsonValue.Create(1), typeof(BsonValue))); var parameters = new ParameterExpression[] { x }; var predicate = Expression.Lambda>(body, parameters); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4819Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4819Tests.cs new file mode 100644 index 00000000000..9f8f49eff4e --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4819Tests.cs @@ -0,0 +1,68 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using MongoDB.Driver.TestHelpers; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp4819Tests : LinqIntegrationTest +{ + public CSharp4819Tests(ClassFixture fixture) + : base(fixture) + { + } + + [Fact] + public void ReplaceWith_should_use_configured_element_name() + { + var collection = Fixture.Collection; + var stage = PipelineStageDefinitionBuilder + .ReplaceWith((User u) => new User { UserId = u.UserId }); + + var aggregate = collection.Aggregate() + .AppendStage(stage); + + var stages = Translate(collection, aggregate); + AssertStages( + stages, + "{ $replaceWith : { uuid : '$uuid' } }"); + + var result = aggregate.Single(); + result.Id.Should().Be(0); + result.UserId.Should().Be(Guid.Parse("00112233-4455-6677-8899-aabbccddeeff")); + } + + public class User + { + public int Id { get; set; } + [BsonElement("uuid")] + [BsonGuidRepresentation(GuidRepresentation.Standard)] + public Guid UserId { get; set; } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + new User { Id = 1, UserId = Guid.Parse("00112233-4455-6677-8899-aabbccddeeff") } + ]; + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4820Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4820Tests.cs new file mode 100644 index 00000000000..18be97f693c --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4820Tests.cs @@ -0,0 +1,114 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Linq; +using MongoDB.Driver.TestHelpers; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp4820Tests : LinqIntegrationTest +{ + public CSharp4820Tests(ClassFixture fixture) + : base(fixture) + { + } + + static CSharp4820Tests() + { + BsonClassMap.RegisterClassMap(cm => + { + cm.AutoMap(); + var readonlyCollectionMemberMap = cm.GetMemberMap(x => x.ReadOnlyCollection); + var readOnlyCollectionSerializer = readonlyCollectionMemberMap.GetSerializer(); + var bracketingCollectionSerializer = ((IChildSerializerConfigurable)readOnlyCollectionSerializer).WithChildSerializer(new StringBracketingSerializer()); + readonlyCollectionMemberMap.SetSerializer(bracketingCollectionSerializer); + }); + } + + [Fact] + public void Update_Set_with_List_should_work() + { + var values = new List() { "abc", "def" }; + var update = Builders.Update.Set(x => x.ReadOnlyCollection, values); + var serializerRegistry = BsonSerializer.SerializerRegistry; + var documentSerializer = serializerRegistry.GetSerializer(); + + var rendered = (BsonDocument)update.Render(new (documentSerializer, serializerRegistry)); + + rendered.Should().Be("{ $set : { ReadOnlyCollection : ['[abc]', '[def]'] } }"); + } + + [Fact] + public void Update_Set_with_Enumerable_should_throw() + { + var values = new[] { "abc", "def" }.Select(x => x); + var update = Builders.Update.Set(x => x.ReadOnlyCollection, values); + var serializerRegistry = BsonSerializer.SerializerRegistry; + var documentSerializer = serializerRegistry.GetSerializer(); + + var rendered = (BsonDocument)update.Render(new (documentSerializer, serializerRegistry)); + + rendered.Should().Be("{ $set : { ReadOnlyCollection : ['[abc]', '[def]'] } }"); + } + + [Fact] + public void Update_Set_with_Enumerable_ToList_should_work() + { + var values = new[] { "abc", "def" }.Select(x => x); + var update = Builders.Update.Set(x => x.ReadOnlyCollection, values.ToList()); + var serializerRegistry = BsonSerializer.SerializerRegistry; + var documentSerializer = serializerRegistry.GetSerializer(); + + var rendered = (BsonDocument)update.Render(new (documentSerializer, serializerRegistry)); + + rendered.Should().Be("{ $set : { ReadOnlyCollection : ['[abc]', '[def]'] } }"); + } + + public class C + { + public int Id { get; set; } + public IReadOnlyCollection ReadOnlyCollection { get; set; } + } + + + private class StringBracketingSerializer : SerializerBase + { + public override string Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + { + var bracketedValue = StringSerializer.Instance.Deserialize(context, args); + return bracketedValue.Substring(1, bracketedValue.Length - 2); + } + + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, string value) + { + var bracketedValue = "[" + value + "]"; + StringSerializer.Instance.Serialize(context, bracketedValue); + } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => null; + // [ + // new C { } + // ]; + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4957Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4957Tests.cs index 791ce3bcd75..e82194ef6cc 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4957Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4957Tests.cs @@ -84,7 +84,7 @@ public void New_array_with_two_items_should_work() [Theory] [ParameterAttributeData] - public void New_array_with_two_items_with_different_serializers_should_throw( + public void New_array_with_two_items_with_different_serializers_should_work( [Values(false, true)] bool enableClientSideProjections) { RequireServer.Check().Supports(Feature.FindProjectionExpressions); @@ -94,21 +94,11 @@ public void New_array_with_two_items_with_different_serializers_should_throw( var queryable = collection.AsQueryable(translationOptions) .Select(x => new[] { x.X, x.Y }); - if (enableClientSideProjections) - { - var stages = Translate(collection, queryable, out var outputSerializer); - AssertStages(stages, "{ $project : { _snippets : ['$X', '$Y'], _id : 0 } }"); - outputSerializer.Should().BeAssignableTo(); - - var result = queryable.Single(); - result.Should().Equal(1, 2); - } - else - { - var exception = Record.Exception(() => Translate(collection, queryable)); - exception.Should().BeOfType(); - exception.Message.Should().Contain("all items in the array must be serialized using the same serializer"); - } + var stages = Translate(collection, queryable, out var outputSerializer); + AssertStages(stages, "{ $project : { _v : ['$X', '$Y'], _id : 0 } }"); + + var result = queryable.Single(); + result.Should().Equal(1, 2); } public class C diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4967Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4967Tests.cs new file mode 100644 index 00000000000..a93e1b4f387 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4967Tests.cs @@ -0,0 +1,75 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using MongoDB.Driver.TestHelpers; +using FluentAssertions; +using MongoDB.Bson.Serialization; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp4967Tests : LinqIntegrationTest +{ + public CSharp4967Tests(ClassFixture fixture) + : base(fixture) + { + } + + [Fact] + public void Set_Nested_should_work() + { + var collection = Fixture.Collection; + var update = Builders.Update + .Pipeline(new EmptyPipelineDefinition() + .Set(c => new MyDocument + { + Nested = new MyNestedDocument + { + ValueCopy = c.Value, + }, + })); + + var renderedUpdate = update.Render(new(collection.DocumentSerializer, BsonSerializer.SerializerRegistry)).AsBsonArray; + renderedUpdate.Count.Should().Be(1); + renderedUpdate[0].Should().Be("{ $set : { Nested : { ValueCopy : '$Value' } } }"); + + collection.UpdateMany("{ }", update); + + var updatedDocument = collection.FindSync("{}").Single(); + updatedDocument.Nested.ValueCopy.Should().Be("Value"); + } + + public class MyDocument + { + public int Id { get; set; } + public string Value { get; set; } + public string AnotherValue { get; set; } + public MyNestedDocument Nested { get; set; } + } + + public class MyNestedDocument + { + public string ValueCopy { get; set; } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + new MyDocument { Id = 1, Value = "Value" } + ]; + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs new file mode 100644 index 00000000000..d188c18fa1f --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs @@ -0,0 +1,225 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq; +using MongoDB.Bson; +using MongoDB.Bson.IO; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Attributes; +using MongoDB.Bson.Serialization.Serializers; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira +{ + public class CSharp5435Tests : Linq3IntegrationTest + { + [Fact] + public void Test_set_ValueObject_Value_using_creator_map() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + ValueObject = new MyValue(x.ValueObject == null ? 1 : x.ValueObject.Value + 1) + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { ValueObject : { Value : { $cond : { if : { $eq : ['$ValueObject', null] }, then : 1, else : { $add : ['$ValueObject.Value', 1] } } } } } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test_set_ValueObject_Value_using_property_setter() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + ValueObject = new MyValue() + { + Value = x.ValueObject == null ? 1 : x.ValueObject.Value + 1 + } + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { ValueObject : { Value : { $cond : { if : { $eq : ['$ValueObject', null] }, then : 1, else : { $add : ['$ValueObject.Value', 1] } } } } } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test_set_ValueObject_to_derived_value_using_property_setter() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + ValueObject = new MyDerivedValue() + { + Value = x.ValueObject == null ? 1 : x.ValueObject.Value + 1, + B = 42 + } + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test_set_X_using_constructor() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + X = new X(x.Y) + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { X : { Y : '$Y' } } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test_set_A() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + A = new [] { 2, x.A[0] } + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { A : ['2', { $arrayElemAt : ['$A', 0] }] } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + private IMongoCollection GetCollection() + { + var collection = GetCollection("test"); + CreateCollection( + collection.Database.GetCollection("test"), + BsonDocument.Parse("{ _id : 1 }"), + BsonDocument.Parse("{ _id : 2, X : null }"), + BsonDocument.Parse("{ _id : 3, X : 3 }")); + return collection; + } + + class MyDocument + { + [BsonRepresentation(MongoDB.Bson.BsonType.ObjectId)] + public string Id { get; set; } = ObjectId.GenerateNewId().ToString(); + + public MyValue ValueObject { get; set; } + + public long Long { get; set; } + + public X X { get; set; } + + public int Y { get; set; } + + [BsonRepresentation(BsonType.String)] + public int[] A { get; set; } + } + + class MyValue + { + [BsonConstructor] + public MyValue() { } + [BsonConstructor] + public MyValue(int value) { Value = value; } + public int Value { get; set; } + } + + class MyDerivedValue : MyValue + { + public int B { get; set; } + } + + [BsonSerializer(typeof(XSerializer))] + class X + { + public X(int y) + { + Y = y; + } + public int Y { get; } + } + + class XSerializer : SerializerBase, IBsonDocumentSerializer + { + public override X Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + { + var reader = context.Reader; + reader.ReadStartArray(); + _ = reader.ReadName(); + var y = reader.ReadInt32(); + reader.ReadEndDocument(); + + return new X(y); + } + + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, X value) + { + var writer = context.Writer; + writer.WriteStartDocument(); + writer.WriteName("Y"); + writer.WriteInt32(value.Y); + writer.WriteEndDocument(); + } + + public bool TryGetMemberSerializationInfo(string memberName, out BsonSerializationInfo serializationInfo) + { + serializationInfo = memberName == "Y" ? new BsonSerializationInfo("Y", Int32Serializer.Instance, typeof(int)) : null; + return serializationInfo != null; + } + } + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5519Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5519Tests.cs new file mode 100644 index 00000000000..30f3a73072a --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5519Tests.cs @@ -0,0 +1,66 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Linq; +using MongoDB.Driver; +using MongoDB.Driver.TestHelpers; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp5519Tests : LinqIntegrationTest +{ + public CSharp5519Tests(ClassFixture fixture) + : base(fixture) + { + } + + [Fact] + public void Array_constant_Any_should_serialize_array_correctly() + { + var collection = Fixture.Collection; + var array = new[] { E.A, E.B }; + + var find = collection.Find(x => array.Any(e => x.E == e)); + + var filter = TranslateFindFilter(collection, find); + filter.Should().Be("{ E : { $in : ['A', 'B'] } }"); + + var results = find.ToList(); + results.Select(x => x.Id).Should().Equal(1, 2); + } + + public class C + { + public int Id { get; set; } + [BsonRepresentation(BsonType.String)] public E E { get; set; } + } + + public enum E { A, B, C } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + new C { Id = 1, E = E.A }, + new C { Id = 2, E = E.B }, + new C { Id = 3, E = E.C } + ]; + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5532Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5532Tests.cs new file mode 100644 index 00000000000..10c9e294982 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5532Tests.cs @@ -0,0 +1,199 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Linq; +using System.Text.RegularExpressions; +using MongoDB.Driver.TestHelpers; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; +using MongoDB.Driver.Core.Misc; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp5532Tests : LinqIntegrationTest +{ + private static readonly ObjectId id1 = ObjectId.Parse("111111111111111111111111"); + private static readonly ObjectId id2 = ObjectId.Parse("222222222222222222222222"); + private static readonly ObjectId id3 = ObjectId.Parse("333333333333333333333333"); + + public CSharp5532Tests(ClassFixture fixture) + : base(fixture) + { + } + + [Fact] + public void Filter_should_translate_correctly() + { + var collection = Fixture.Collection; + List jobIds = [id2.ToString()]; + + var find = collection + .Find(x => x.Parts.Any(a => a.Refs.Any(b => jobIds.Contains(b.id)))); + + var filter = TranslateFindFilter(collection, find); + + filter.Should().Be("{ Parts : { $elemMatch : { Refs : { $elemMatch : { _id : { $in : [ObjectId('222222222222222222222222')] } } } } } }"); + } + + [Fact] + public void Projection_should_translate_correctly() + { + var collection = Fixture.Collection; + List jobIds = [id2.ToString()]; + + var find = collection + .Find("{}") + .Project(chain => + new + { + chain.Parts + .First(p => p.Refs.Any(j => jobIds.Contains(j.id))) + .Refs.First(j => jobIds.Contains(j.id)).id + });; + + var projectionTranslation = TranslateFindProjection(collection, find); + + var expectedTranslation = + """ + { + _id : + { + $let : + { + vars : + { + this : + { + $arrayElemAt : + [ + { + $filter : + { + input : + { + $let : + { + vars : + { + this : + { + $arrayElemAt : + [ + { + $filter : + { + input : "$Parts", + as : "p", + cond : + { + $anyElementTrue : + { + $map : + { + input : "$$p.Refs", + as : "j", + in : { $in : ["$$j._id", [{ "$oid" : "222222222222222222222222" }]] } + } + } + }, + limit : 1 + } + }, + 0 + ] + } + }, + in : "$$this.Refs" + } + }, + as : "j", + cond : { $in : ['$$j._id', [{ "$oid" : "222222222222222222222222" }]] }, + limit : 1 + } + }, + 0 + ] + } + }, + in : "$$this._id" + } + } + } + """; + if (!Feature.FilterLimit.IsSupported(CoreTestConfiguration.MaxWireVersion)) + { + expectedTranslation = Regex.Replace(expectedTranslation, @",\s+limit : 1", ""); + } + + projectionTranslation.Should().Be(expectedTranslation); + } + + public class Document + { + [BsonId] + [BsonRepresentation(BsonType.ObjectId)] + public string id { get; set; } + } + + public class Chain : Document + { + public ICollection Parts { get; set; } = new List(); + } + + public class Unit + { + public ICollection Refs { get; set; } + + public Unit() + { + Refs = new List(); + } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + new Chain + { + id = "0102030405060708090a0b0c", + Parts = new List() + { + new() + { + Refs = new List() + { + new() + { + id = id1.ToString(), + }, + new() + { + id = id2.ToString(), + }, + new() + { + id = id3.ToString(), + }, + } + } + } + } + ]; + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Serializers/EnumUnderlyingTypeSerializerTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Serializers/AsEnumUnderlyingTypeSerializerTests.cs similarity index 68% rename from tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Serializers/EnumUnderlyingTypeSerializerTests.cs rename to tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Serializers/AsEnumUnderlyingTypeSerializerTests.cs index f6f7ace6d48..de10de29a04 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Serializers/EnumUnderlyingTypeSerializerTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Serializers/AsEnumUnderlyingTypeSerializerTests.cs @@ -22,7 +22,7 @@ namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Serializers { - public class EnumUnderlyingTypeSerializerTests + public class AsEnumUnderlyingTypeSerializerTests { private static readonly IBsonSerializer __enumSerializer1 = new ESerializer1(); private static readonly IBsonSerializer __enumSerializer2 = new ESerializer2(); @@ -30,8 +30,8 @@ public class EnumUnderlyingTypeSerializerTests [Fact] public void Equals_derived_should_return_false() { - var x = new EnumUnderlyingTypeSerializer(__enumSerializer1); - var y = new DerivedFromEnumUnderlyingTypeSerializer(__enumSerializer1); + var x = new AsEnumUnderlyingTypeSerializer(__enumSerializer1); + var y = new DerivedFromAsEnumUnderlyingTypeSerializer(__enumSerializer1); var result = x.Equals(y); @@ -41,7 +41,7 @@ public void Equals_derived_should_return_false() [Fact] public void Equals_null_should_return_false() { - var x = new EnumUnderlyingTypeSerializer(__enumSerializer1); + var x = new AsEnumUnderlyingTypeSerializer(__enumSerializer1); var result = x.Equals(null); @@ -51,7 +51,7 @@ public void Equals_null_should_return_false() [Fact] public void Equals_object_should_return_false() { - var x = new EnumUnderlyingTypeSerializer(__enumSerializer1); + var x = new AsEnumUnderlyingTypeSerializer(__enumSerializer1); var y = new object(); var result = x.Equals(y); @@ -62,7 +62,7 @@ public void Equals_object_should_return_false() [Fact] public void Equals_self_should_return_true() { - var x = new EnumUnderlyingTypeSerializer(__enumSerializer1); + var x = new AsEnumUnderlyingTypeSerializer(__enumSerializer1); var result = x.Equals(x); @@ -72,8 +72,8 @@ public void Equals_self_should_return_true() [Fact] public void Equals_with_equal_fields_should_return_true() { - var x = new EnumUnderlyingTypeSerializer(__enumSerializer1); - var y = new EnumUnderlyingTypeSerializer(__enumSerializer1); + var x = new AsEnumUnderlyingTypeSerializer(__enumSerializer1); + var y = new AsEnumUnderlyingTypeSerializer(__enumSerializer1); var result = x.Equals(y); @@ -83,8 +83,8 @@ public void Equals_with_equal_fields_should_return_true() [Fact] public void Equals_with_not_equal_field_should_return_false() { - var x = new EnumUnderlyingTypeSerializer(__enumSerializer1); - var y = new EnumUnderlyingTypeSerializer(__enumSerializer2); + var x = new AsEnumUnderlyingTypeSerializer(__enumSerializer1); + var y = new AsEnumUnderlyingTypeSerializer(__enumSerializer2); var result = x.Equals(y); @@ -94,18 +94,18 @@ public void Equals_with_not_equal_field_should_return_false() [Fact] public void GetHashCode_should_return_zero() { - var x = new EnumUnderlyingTypeSerializer(__enumSerializer1); + var x = new AsEnumUnderlyingTypeSerializer(__enumSerializer1); var result = x.GetHashCode(); result.Should().Be(0); } - internal class DerivedFromEnumUnderlyingTypeSerializer : EnumUnderlyingTypeSerializer + internal class DerivedFromAsEnumUnderlyingTypeSerializer : AsEnumUnderlyingTypeSerializer where TEnum : Enum where TEnumUnderlyingType : struct { - public DerivedFromEnumUnderlyingTypeSerializer(IBsonSerializer enumSerializer) : base(enumSerializer) + public DerivedFromAsEnumUnderlyingTypeSerializer(IBsonSerializer enumSerializer) : base(enumSerializer) { } } diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs index 10f3f2a5d14..08f071902ab 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs @@ -31,8 +31,8 @@ public class ModuloComparisonExpressionToFilterTranslatorTests [Fact] public void Translate_should_return_expected_result_with_byte_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.Byte % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.Byte % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -44,8 +44,8 @@ public void Translate_should_return_expected_result_with_byte_arguments() [Fact] public void Translate_should_return_expected_result_with_decimal_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.Decimal % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.Decimal % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -57,8 +57,8 @@ public void Translate_should_return_expected_result_with_decimal_arguments() [Fact] public void Translate_should_return_expected_result_with_double_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.Double % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.Double % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -70,8 +70,8 @@ public void Translate_should_return_expected_result_with_double_arguments() [Fact] public void Translate_should_return_expected_result_with_float_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.Float % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.Float % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -83,8 +83,8 @@ public void Translate_should_return_expected_result_with_float_arguments() [Fact] public void Translate_should_return_expected_result_with_int16_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.Int16 % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.Int16 % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -96,8 +96,8 @@ public void Translate_should_return_expected_result_with_int16_arguments() [Fact] public void Translate_should_return_expected_result_with_int32_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.Int32 % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.Int32 % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -109,8 +109,8 @@ public void Translate_should_return_expected_result_with_int32_arguments() [Fact] public void Translate_should_return_expected_result_with_int64_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.Int64 % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.Int64 % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -122,8 +122,8 @@ public void Translate_should_return_expected_result_with_int64_arguments() [Fact] public void Translate_should_return_expected_result_with_sbyte_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.SByte % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.SByte % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -135,8 +135,8 @@ public void Translate_should_return_expected_result_with_sbyte_arguments() [Fact] public void Translate_should_return_expected_result_with_uint16_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.UInt16 % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.UInt16 % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -148,8 +148,8 @@ public void Translate_should_return_expected_result_with_uint16_arguments() [Fact] public void Translate_should_return_expected_result_with_uint32_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.UInt32 % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.UInt32 % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -161,8 +161,8 @@ public void Translate_should_return_expected_result_with_uint32_arguments() [Fact] public void Translate_should_return_expected_result_with_uint64_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.UInt64 % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.UInt64 % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -180,19 +180,19 @@ private void Assert(AstFilter result, string path, BsonValue divisor, BsonValue modFilterOperation.Remainder.Should().Be(remainder); } - private TranslationContext CreateContext(ParameterExpression parameter) + private TranslationContext CreateContext(LambdaExpression lambda) { + var parameter = lambda.Parameters.Single(); var serializer = BsonSerializer.LookupSerializer(parameter.Type); - var context = TranslationContext.Create(translationOptions: null); + var context = TranslationContext.Create(lambda, parameter, serializer, translationOptions: null); var symbol = context.CreateSymbol(parameter, serializer, isCurrent: true); return context.WithSymbol(symbol); } - private (ParameterExpression, BinaryExpression) CreateExpression(Expression> lambda) + private (LambdaExpression, BinaryExpression) CreateExpression(Expression> lambda) { - var parameter = lambda.Parameters.Single(); var expression = (BinaryExpression)lambda.Body; - return (parameter, expression); + return (lambda, expression); } private class C diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs index a8f7428079b..5c8f9e967e1 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs @@ -641,7 +641,7 @@ private ProjectedResult Group(Expression Project(Expression var query = __collection.AsQueryable().Select(projector); var provider = (MongoQueryProvider)query.Provider; + var inputSerializer = (IBsonSerializer)provider.PipelineInputSerializer; + var serializerRegistry = provider.Collection.Settings.SerializerRegistry; var translationOptions = new ExpressionTranslationOptions { EnableClientSideProjections = false }; - var executableQuery = ExpressionToExecutableQueryTranslator.Translate(provider, query.Expression, translationOptions); - var projection = executableQuery.Pipeline.Ast.Stages.First().Render()["$project"].AsBsonDocument; + var renderedProjection = LinqProviderAdapter.TranslateExpressionToProjection( + projector, + inputSerializer, + serializerRegistry, + translationOptions); + + var projection = renderedProjection.Document; var value = query.Take(1).FirstOrDefault(); return new ProjectedResult diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/LegacyPredicateTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/LegacyPredicateTranslatorTests.cs index fa01543be13..a4c53d6fd03 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/LegacyPredicateTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/LegacyPredicateTranslatorTests.cs @@ -1184,7 +1184,7 @@ private void Assert(Expression> expression, int var parameter = expression.Parameters.Single(); var serializer = BsonSerializer.LookupSerializer(); - var context = TranslationContext.Create(translationOptions: null); + var context = TranslationContext.Create(expression, parameter, serializer, translationOptions: null); var symbol = context.CreateSymbol(parameter, serializer, isCurrent: true); context = context.WithSymbol(symbol); var filterAst = ExpressionToFilterTranslator.Translate(context, expression.Body); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs index c96a2a96b9a..2e306b35c5a 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs @@ -1152,9 +1152,9 @@ public List Assert(IMongoCollection collection, { filter = (Expression>)LinqExpressionPreprocessor.Preprocess(filter); - var serializer = BsonSerializer.SerializerRegistry.GetSerializer(); var parameter = filter.Parameters.Single(); - var context = TranslationContext.Create(translationOptions: null); + var serializer = BsonSerializer.SerializerRegistry.GetSerializer(); + var context = TranslationContext.Create(filter, parameter, serializer, translationOptions: null); var symbol = context.CreateSymbol(parameter, serializer, isCurrent: true); context = context.WithSymbol(symbol); var filterAst = ExpressionToFilterTranslator.Translate(context, filter.Body);