Skip to content

Commit

Permalink
Also simplify a == null ? a : null and various other cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Dec 15, 2024
1 parent d598bfb commit d546033
Show file tree
Hide file tree
Showing 16 changed files with 362 additions and 260 deletions.
1 change: 1 addition & 0 deletions EFCore.sln.DotSettings
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ The .NET Foundation licenses this file to you under the MIT license.
<s:Boolean x:Key="/Default/UserDictionary/Words/=subquery/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=subquery_0027s/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=transactionality/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=uncoalesce/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=uncoalescing/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=unconfigured/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=unequality/@EntryIndexedValue">True</s:Boolean>
Expand Down
62 changes: 24 additions & 38 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,10 @@ public virtual SqlExpression Case(
elseResult = lastCase.ElseResult;
}

// Optimize:
// Simplify:
// a == null ? null : a -> a
// a != null ? a : null -> a
// And lift:
// a == b ? null : a -> NULLIF(a, b)
// a != b ? a : null -> NULLIF(a, b)
if (operand is null
Expand All @@ -835,18 +838,28 @@ public virtual SqlExpression Case(
Test: SqlBinaryExpression { OperatorType: ExpressionType.Equal or ExpressionType.NotEqual } binary,
Result: var result
}
])
]
&& binary.OperatorType switch
{
ExpressionType.Equal when result is SqlConstantExpression { Value: null } && elseResult is not null => elseResult,
ExpressionType.NotEqual when elseResult is null or SqlConstantExpression { Value: null } => result,
_ => null
} is SqlExpression conditionalResult)
{
switch (binary.OperatorType)
var (left, right) = (binary.Left, binary.Right);

if (left.Equals(conditionalResult))
{
case ExpressionType.Equal
when result is SqlConstantExpression { Value: null }
&& elseResult is not null
&& TryTranslateToNullIf(elseResult, out var nullIfTranslation):
case ExpressionType.NotEqual
when elseResult is null or SqlConstantExpression { Value: null }
&& TryTranslateToNullIf(result, out nullIfTranslation):
return nullIfTranslation;
return right is SqlConstantExpression { Value: null }
? left
: Function("NULLIF", [left, right], nullable: true, [false, false], left.Type, left.TypeMapping);
}

if (right.Equals(conditionalResult))
{
return left is SqlConstantExpression { Value: null }
? right
: Function("NULLIF", [right, left], nullable: true, [false, false], right.Type, right.TypeMapping);
}
}

Expand All @@ -862,33 +875,6 @@ bool IsSkipped(CaseWhenClause clause)

bool IsMatched(CaseWhenClause clause)
=> operand is null && clause.Test is SqlConstantExpression { Value: true };

bool TryTranslateToNullIf(SqlExpression conditionalResult, [NotNullWhen(true)] out SqlExpression? nullIfTranslation)
{
var (left, right) = (binary.Left, binary.Right);

// If one of sides of the equality is equal to the result of the conditional - a == b ? null : a - convert to
// NULLIF(a, b).
// Specifically refrain from doing so for when the other side is a null constant, as that would transform a == null ? null : a
// to NULLIF(a, NULL), which we don't want.

if (left.Equals(conditionalResult) && right is not SqlConstantExpression { Value: null })
{
nullIfTranslation = Function(
"NULLIF", [left, right], true, [false, false], left.Type, left.TypeMapping);
return true;
}

if (right.Equals(conditionalResult) && left is not SqlConstantExpression { Value: null })
{
nullIfTranslation = Function(
"NULLIF", [right, left], true, [false, false], right.Type, right.TypeMapping);
return true;
}

nullIfTranslation = null;
return false;
}
}

/// <inheritdoc />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,66 +175,6 @@ public override async Task TimeSpan_Compare_to_simple_zero(bool async, bool comp

#endregion Compare

#region Uncoalescing conditional / NullIf

public override Task Uncoalescing_conditional_with_equality_left(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Uncoalescing_conditional_with_equality_left(a);

AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (((c["Int"] = 9) ? null : c["Int"]) > 1)
""");
});

public override Task Uncoalescing_conditional_with_equality_right(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Uncoalescing_conditional_with_equality_right(a);

AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (((9 = c["Int"]) ? null : c["Int"]) > 1)
""");
});

public override Task Uncoalescing_conditional_with_unequality_left(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Uncoalescing_conditional_with_unequality_left(a);

AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (((c["Int"] != 9) ? c["Int"] : null) > 1)
""");
});

public override Task Uncoalescing_conditional_with_inequality_right(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Uncoalescing_conditional_with_inequality_right(a);

AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (((9 != c["Int"]) ? c["Int"] : null) > 1)
""");
});

#endregion Uncoalescing conditional / NullIf

[ConditionalFact]
public virtual void Check_all_tests_overridden()
=> TestHelpers.AssertAllMethodsOverridden(GetType());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace Microsoft.EntityFrameworkCore.Query.Translations;

public class OperatorTranslationsCosmosTest : OperatorTranslationsTestBase<BasicTypesQueryCosmosFixture>
{
public OperatorTranslationsCosmosTest(BasicTypesQueryCosmosFixture fixture, ITestOutputHelper testOutputHelper) : base(fixture)
{
Fixture.TestSqlLoggerFactory.Clear();
Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
}

#region Conditional

public override Task Conditional_uncoalesce_with_equality_left(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Conditional_uncoalesce_with_equality_left(a);

AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (((c["Int"] = 9) ? null : c["Int"]) > 1)
""");
});

public override Task Conditional_uncoalesce_with_equality_right(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Conditional_uncoalesce_with_equality_right(a);

AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (((9 = c["Int"]) ? null : c["Int"]) > 1)
""");
});

public override Task Conditional_uncoalesce_with_unequality_left(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Conditional_uncoalesce_with_unequality_left(a);

AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (((c["Int"] != 9) ? c["Int"] : null) > 1)
""");
});

public override Task Conditional_uncoalesce_with_inequality_right(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Conditional_uncoalesce_with_inequality_right(a);

AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (((9 != c["Int"]) ? c["Int"] : null) > 1)
""");
});

#endregion Conditional

[ConditionalFact]
public virtual void Check_all_tests_overridden()
=> TestHelpers.AssertAllMethodsOverridden(GetType());

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace Microsoft.EntityFrameworkCore.Query.Translations;

public class OperatorTranslationsInMemoryTest(BasicTypesQueryInMemoryFixture fixture)
: OperatorTranslationsTestBase<BasicTypesQueryInMemoryFixture>(fixture);
Original file line number Diff line number Diff line change
Expand Up @@ -429,38 +429,4 @@ await AssertQuery(
}

#endregion

#region Uncoalescing conditional

// In relational providers, x == a ? null : x is translated to SQL NULLIF

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Uncoalescing_conditional_with_equality_left(bool async)
=> AssertQuery(
async,
cs => cs.Set<BasicTypesEntity>().Where(x => (x.Int == 9 ? null : x.Int) > 1));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Uncoalescing_conditional_with_equality_right(bool async)
=> AssertQuery(
async,
cs => cs.Set<BasicTypesEntity>().Where(x => (9 == x.Int ? null : x.Int) > 1));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Uncoalescing_conditional_with_unequality_left(bool async)
=> AssertQuery(
async,
cs => cs.Set<BasicTypesEntity>().Where(x => (x.Int != 9 ? x.Int : null) > 1));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Uncoalescing_conditional_with_inequality_right(bool async)
=> AssertQuery(
async,
cs => cs.Set<BasicTypesEntity>().Where(x => (9 != x.Int ? x.Int : null) > 1));

#endregion Uncoalescing conditional
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.EntityFrameworkCore.TestModels.BasicTypesModel;

namespace Microsoft.EntityFrameworkCore.Query.Translations;

public abstract class OperatorTranslationsTestBase<TFixture>(TFixture fixture) : QueryTestBase<TFixture>(fixture)
where TFixture : BasicTypesQueryFixtureBase, new()
{
#region Conditional

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Conditional_simplifiable_equality(bool async)
=> AssertQuery(
async,
// ReSharper disable once MergeConditionalExpression
cs => cs.Set<NullableBasicTypesEntity>().Where(x => (x.Int == null ? null : x.Int) > 1));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Conditional_simplifiable_inequality(bool async)
=> AssertQuery(
async,
// ReSharper disable once MergeConditionalExpression
cs => cs.Set<NullableBasicTypesEntity>().Where(x => (x.Int != null ? x.Int : null) > 1));

// In relational providers, x == a ? null : x ("un-coalescing conditional") is translated to SQL NULLIF

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Conditional_uncoalesce_with_equality_left(bool async)
=> AssertQuery(
async,
cs => cs.Set<BasicTypesEntity>().Where(x => (x.Int == 9 ? null : x.Int) > 1));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Conditional_uncoalesce_with_equality_right(bool async)
=> AssertQuery(
async,
cs => cs.Set<BasicTypesEntity>().Where(x => (9 == x.Int ? null : x.Int) > 1));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Conditional_uncoalesce_with_unequality_left(bool async)
=> AssertQuery(
async,
cs => cs.Set<BasicTypesEntity>().Where(x => (x.Int != 9 ? x.Int : null) > 1));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Conditional_uncoalesce_with_inequality_right(bool async)
=> AssertQuery(
async,
cs => cs.Set<BasicTypesEntity>().Where(x => (9 != x.Int ? x.Int : null) > 1));

#endregion Conditional
}
Original file line number Diff line number Diff line change
Expand Up @@ -855,9 +855,7 @@ public override async Task Select_null_propagation_works_for_multiple_navigation

AssertSql(
"""
SELECT CASE
WHEN [c].[Name] IS NOT NULL THEN [c].[Name]
END
SELECT [c].[Name]
FROM [Tags] AS [t]
LEFT JOIN [Gears] AS [g] ON [t].[GearNickName] = [g].[Nickname] AND [t].[GearSquadId] = [g].[SquadId]
LEFT JOIN [Tags] AS [t0] ON ([g].[Nickname] = [t0].[GearNickName] OR ([g].[Nickname] IS NULL AND [t0].[GearNickName] IS NULL)) AND ([g].[SquadId] = [t0].[GearSquadId] OR ([g].[SquadId] IS NULL AND [t0].[GearSquadId] IS NULL))
Expand Down Expand Up @@ -3057,9 +3055,7 @@ public override async Task Select_null_conditional_with_inheritance(bool async)

AssertSql(
"""
SELECT CASE
WHEN [f].[CommanderName] IS NOT NULL THEN [f].[CommanderName]
END
SELECT [f].[CommanderName]
FROM [Factions] AS [f]
""");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ public NorthwindFunctionsQuerySqlServer160Test(Fixture160 fixture, ITestOutputHe
public virtual void Check_all_tests_overridden()
=> TestHelpers.AssertAllMethodsOverridden(GetType());

public override async Task Client_evaluation_of_uncorrelated_method_call(bool async)
{
await base.Client_evaluation_of_uncorrelated_method_call(async);

AssertSql(
"""
SELECT [o].[OrderID], [o].[ProductID], [o].[Discount], [o].[Quantity], [o].[UnitPrice]
FROM [Order Details] AS [o]
WHERE [o].[UnitPrice] < 7.0 AND 10 < [o].[ProductID]
""");
}

public override async Task Sum_over_round_works_correctly_in_projection(bool async)
{
await base.Sum_over_round_works_correctly_in_projection(async);
Expand Down
Loading

0 comments on commit d546033

Please sign in to comment.