Methods for evaluating generated polynomials

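I'm trying to come up with an elegant way to handle some generated polynomials. Here's the situation we'll focus on (exclusively) for this question: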

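  1. order is the parameter used to generate the polynomial: the element has n := order + 1 nodes x_0, …, x_order, and the polynomial has degree order = n − 1.
  2. i is an integer parameter in the range 0..order.
  3. The polynomial has zeros at x_j, where j = 0..order and j ≠ i. (At this point it should be clear that StackOverflow needs a new feature, or it has one and I don't know about it.)
  4. The polynomial evaluates to 1 at x_i.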

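Since this particular code sample generates the nodes itself, I'll explain how they're found in the code: they are evenly spaced, x_j = j * elementSize / order for j = 0..order, so adjacent nodes are elementSize / order apart and there are n = order + 1 of them.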
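A minimal sketch of that node formula on its own, assuming the same elementSize and order parameters; ElementNodes and Nodes are hypothetical names, not part of the original code.

```csharp
using System;
using System.Linq;

// Hypothetical helper: the evenly spaced nodes that GeneratePsi uses.
public static class ElementNodes
{
    // Returns x_j = j * elementSize / order for j = 0..order (order + 1 nodes).
    public static double[] Nodes(double elementSize, int order)
    {
        if (order < 1)
            throw new ArgumentOutOfRangeException(nameof(order), "order must be greater than 0.");

        return Enumerable.Range(0, order + 1)
                         .Select(j => j * elementSize / order)
                         .ToArray();
    }
}
```

For example, Nodes(1.0, 2) gives the three nodes 0, 0.5 and 1.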

I generate a Func<double, double> to evaluate this polynomial¹.

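    // requires: using System; using System.Collections.Generic; using System.Linq; using System.Linq.Expressions;
    // (place the method inside a class)
    private static Func<double, double> GeneratePsi(double elementSize, int order, int i)
    {
        if (order < 1)
            throw new ArgumentOutOfRangeException("order", "order must be greater than 0.");
        if (i < 0)
            throw new ArgumentOutOfRangeException("i", "i cannot be less than 0.");
        if (i > order)
            throw new ArgumentOutOfRangeException("i", "i cannot be greater than order.");

        ParameterExpression xp = Expression.Parameter(typeof(double), "x");

        // generate the terms of the factored polynomial in the form (x_j - x), skipping j == i
        List<Expression> factors = new List<Expression>();
        for (int j = 0; j <= order; j++)
        {
            if (j == i)
                continue;

            double xj = j * elementSize / order;
            factors.Add(Expression.Subtract(Expression.Constant(xj), xp));
        }

        // scaleInv = product of (x_j - x_i) over j != i, i.e. the factored polynomial evaluated at x_i
        double xi = i * elementSize / order;
        double scaleInv = Enumerable.Range(0, order + 1)
            .Aggregate(1.0, (product, j) => product * (j == i ? 1.0 : (j * elementSize / order - xi)));

        /* generate an expression to evaluate
         *   (x_0 - x) * (x_1 - x) .. (x_order - x) / (x_i - x)
         * obviously the term (x_i - x) is cancelled in this result, but it is included here to make the result clear
         */
        Expression expr = factors.Skip(1).Aggregate(factors[0], (product, factor) => Expression.Multiply(product, factor));

        // multiplying by scale = 1 / scaleInv forces the condition f(x_i) = 1
        expr = Expression.Multiply(Expression.Constant(1.0 / scaleInv), expr);

        Expression<Func<double, double>> lambdaMethod = Expression.Lambda<Func<double, double>>(expr, xp);
        return lambdaMethod.Compile();
    }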
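As a cross-check for the compiled delegate, here is a minimal sketch that evaluates the same polynomial directly, ψ_i(x) = Π_{j≠i} (x_j − x)/(x_j − x_i), with plain loops and no expression trees. It assumes the evenly spaced nodes described above; LagrangeReference and Psi are hypothetical names.

```csharp
using System;

// Hypothetical reference evaluator for the polynomial that GeneratePsi compiles.
public static class LagrangeReference
{
    // psi_i(x) = product over j != i of (x_j - x) / (x_j - x_i),
    // with nodes x_j = j * elementSize / order, j = 0..order.
    public static double Psi(double elementSize, int order, int i, double x)
    {
        double xi = i * elementSize / order;
        double result = 1.0;
        for (int j = 0; j <= order; j++)
        {
            if (j == i)
                continue;

            double xj = j * elementSize / order;
            result *= (xj - x) / (xj - xi);
        }
        return result;
    }
}
```

Evaluating GeneratePsi(elementSize, order, i)(x) and LagrangeReference.Psi(elementSize, order, i, x) at a few points should give the same values up to rounding. Summing Psi over i = 0..order at any x should also give 1, since these are the Lagrange basis polynomials for the nodes.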

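The problem: I also need to evaluate ψ′ = dψ/dx. Since one of the n = order + 1 factors cancels, ψ has degree order = n − 1, so I can rewrite

$$\psi(x) = \text{scale} \times \frac{(x_0 - x)(x_1 - x)\cdots(x_{\text{order}} - x)}{x_i - x}$$

in the form

$$\psi(x) = \alpha_{n-1} x^{n-1} + \alpha_{n-2} x^{n-2} + \cdots + \alpha_1 x + \alpha_0.$$

This gives

$$\psi'(x) = (n-1)\,\alpha_{n-1} x^{n-2} + (n-2)\,\alpha_{n-2} x^{n-3} + \cdots + 1 \times \alpha_1.$$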
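A minimal numeric sketch of this rewrite, assuming the same evenly spaced nodes as GeneratePsi: it multiplies out the factors into the monomial coefficients α_k and then applies the power rule to get the coefficients of ψ′. It works on plain double arrays rather than on expression trees, so it is only a way to cross-check the symbolic approach; PsiCoefficients, Alpha and Derivative are hypothetical names.

```csharp
using System;

// Hypothetical helper: monomial coefficients of psi_i and of its derivative.
public static class PsiCoefficients
{
    // alpha[k] is the coefficient of x^k in psi_i(x), k = 0..order.
    public static double[] Alpha(double elementSize, int order, int i)
    {
        double xi = i * elementSize / order;
        double[] poly = { 1.0 };  // running product, starts as the constant 1
        double scaleInv = 1.0;    // product of (x_j - x_i) over j != i

        for (int j = 0; j <= order; j++)
        {
            if (j == i)
                continue;

            double xj = j * elementSize / order;

            // multiply the running product by (x_j - x)
            var next = new double[poly.Length + 1];
            for (int k = 0; k < poly.Length; k++)
            {
                next[k] += xj * poly[k];  // x_j * (a_k x^k)
                next[k + 1] -= poly[k];   // -x * (a_k x^k)
            }
            poly = next;
            scaleInv *= xj - xi;
        }

        // divide by scaleInv so that psi_i(x_i) = 1
        for (int k = 0; k < poly.Length; k++)
            poly[k] /= scaleInv;
        return poly;
    }

    // Power rule: d/dx sum alpha_k x^k = sum (k + 1) alpha_(k+1) x^k.
    public static double[] Derivative(double[] alpha)
    {
        if (alpha.Length <= 1)
            return new[] { 0.0 };

        var d = new double[alpha.Length - 1];
        for (int k = 1; k < alpha.Length; k++)
            d[k - 1] = k * alpha[k];
        return d;
    }
}
```

For elementSize = 1, order = 2, i = 1 the nodes are 0, 0.5 and 1, so ψ(x) = 4x − 4x² and ψ′(x) = 4 − 8x; Alpha returns { 0, 4, −4 } and Derivative returns { 4, −8 }.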

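For computational reasons, the final answer can be rewritten in nested form so that it can be evaluated without calling Math.Pow:

$$\psi'(x) = x\bigl(x\bigl(x(\cdots) - \beta_2\bigr) - \beta_1\bigr) - \beta_0, \qquad \beta_k = -(k+1)\,\alpha_{k+1}.$$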
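A sketch of building that nested (Horner) form directly as an expression tree, assuming the coefficients of ψ′ are already available as a double array c, where c[k] multiplies x^k (so β_k = −c[k]). HornerExpression is a hypothetical name. The compiled delegate performs only multiplications and additions.

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical helper: compiles a polynomial in nested form, without Math.Pow.
public static class HornerExpression
{
    // c[k] is the coefficient of x^k. Builds x*(x*(...*(x*c[m] + c[m-1])...) + c[1]) + c[0].
    public static Func<double, double> Compile(double[] c)
    {
        if (c == null || c.Length == 0)
            throw new ArgumentException("at least one coefficient is required", nameof(c));

        ParameterExpression x = Expression.Parameter(typeof(double), "x");
        Expression body = Expression.Constant(c[c.Length - 1]);
        for (int k = c.Length - 2; k >= 0; k--)
        {
            // body := x * body + c[k]
            body = Expression.Add(Expression.Multiply(x, body), Expression.Constant(c[k]));
        }
        return Expression.Lambda<Func<double, double>>(body, x).Compile();
    }
}
```

For ψ′(x) = 4 − 8x, Compile(new[] { 4.0, -8.0 }) builds x * −8 + 4, which evaluates to 2 at x = 0.25.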

To do all this "trickery" (all of it very basic algebra), I need a clean way to:

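  1. Expand an Expression whose leaves are ConstantExpression and ParameterExpression nodes and whose internal nodes are basic mathematical operations (ultimately BinaryExpression nodes with NodeType set to the operation). The result here may contain calls to Math.Pow (MethodCallExpression nodes), which we'll always handle in a special way.
  2. Then take the derivative with respect to some specified ParameterExpression. Terms in the result where the right-hand argument of a Math.Pow call was the constant 2 are replaced by ConstantExpression(2) multiplied by the left-hand argument (the resulting call Math.Pow(x, 1) is removed). Terms in the result that become zero because they are constant with respect to x are removed.
  3. Then factor out the instances of some specific ParameterExpression where they appear as the left-hand argument of a Math.Pow call. When the right-hand argument of the call becomes a ConstantExpression with the value 1, replace the call with just the ParameterExpression itself.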
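The steps above treat Math.Pow specially, so it helps to see how x² can appear in a System.Linq.Expressions tree: either as a BinaryExpression with NodeType Power, or as an explicit call to Math.Pow, which is a MethodCallExpression (NodeType Call) rather than an InvocationExpression. A visitor that handles "Math.Pow calls" may need to recognize both shapes. A small sketch; PowNodes and its members are hypothetical names.

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical demo of the two ways to write x^2 in an expression tree.
public static class PowNodes
{
    static readonly ParameterExpression X = Expression.Parameter(typeof(double), "x");
    static readonly Expression Two = Expression.Constant(2.0);

    // x^2 as a BinaryExpression whose NodeType is Power (double operands)
    public static Expression ViaPower() => Expression.Power(X, Two);

    // x^2 as an explicit call to Math.Pow(double, double): a MethodCallExpression, NodeType Call
    public static Expression ViaCall() =>
        Expression.Call(typeof(Math).GetMethod("Pow", new[] { typeof(double), typeof(double) }), X, Two);

    // compiles the body as a Func<double, double> over X and evaluates it
    public static double Evaluate(Expression body, double value) =>
        Expression.Lambda<Func<double, double>>(body, X).Compile()(value);
}
```

Both forms evaluate to the same value; they differ only in how a visitor sees them (VisitBinary versus VisitMethodCall).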

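¹ In the future, I'd like to have a method that takes a ParameterExpression and returns an Expression that evaluates the polynomial in terms of that parameter, so that I can aggregate the generated functions. I don't have that yet.

² In the future, I hope to publish a general-purpose library that uses LINQ expressions for symbolic math.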
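A minimal sketch of the idea in footnote ¹, assuming the same evenly spaced nodes: PsiBody builds only the body of ψ_i against a caller-supplied ParameterExpression, so several bodies can be combined before a single Compile call. PsiExpressions, PsiBody and SumOfBasis are hypothetical names. Unlike GeneratePsi, this folds each 1/(x_j − x_i) into a per-factor constant instead of applying one overall scale.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical sketch: build psi_i against a shared parameter, then aggregate.
public static class PsiExpressions
{
    public static Expression PsiBody(ParameterExpression x, double elementSize, int order, int i)
    {
        double xi = i * elementSize / order;
        Expression body = Expression.Constant(1.0);
        for (int j = 0; j <= order; j++)
        {
            if (j == i)
                continue;

            double xj = j * elementSize / order;
            // multiply by (x_j - x) / (x_j - x_i), with 1 / (x_j - x_i) folded into a constant
            body = Expression.Multiply(
                body,
                Expression.Multiply(
                    Expression.Constant(1.0 / (xj - xi)),
                    Expression.Subtract(Expression.Constant(xj), x)));
        }
        return body;
    }

    // Example aggregation: the sum of all order + 1 basis polynomials of one element.
    public static Func<double, double> SumOfBasis(double elementSize, int order)
    {
        ParameterExpression x = Expression.Parameter(typeof(double), "x");
        Expression sum = Enumerable.Range(0, order + 1)
            .Select(i => PsiBody(x, elementSize, order, i))
            .Aggregate((acc, term) => Expression.Add(acc, term));
        return Expression.Lambda<Func<double, double>>(sum, x).Compile();
    }
}
```

Because the basis polynomials of an element sum to 1, SumOfBasis should return a delegate that evaluates to 1 everywhere (up to rounding), which makes a convenient check that the aggregation works.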

I wrote the basics of several symbolic math features using the ExpressionVisitor type in .NET 4. It isn't perfect, but it looks like the foundation of a viable solution.

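  • Symbolic is a public static class that exposes methods such as Expand, Simplify, and PartialDerivative
  • ExpandVisitor is an internal helper type that expands expressions
  • SimplifyVisitor is an internal helper type that simplifies expressions
  • DerivativeVisitor is an internal helper type that takes the derivative of an expression
  • ListPrintVisitor is an internal helper type that converts an Expression to prefix notation with Lisp syntax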

Symbolic

    using System.Linq.Expressions;

    public static class Symbolic
    {
        public static Expression Expand(Expression expression)
        {
            return new ExpandVisitor().Visit(expression);
        }

        public static Expression Simplify(Expression expression)
        {
            return new SimplifyVisitor().Visit(expression);
        }

        public static Expression PartialDerivative(Expression expression, ParameterExpression parameter)
        {
            bool totalDerivative = false;
            return new DerivativeVisitor(parameter, totalDerivative).Visit(expression);
        }

        public static string ToString(Expression expression)
        {
            ConstantExpression result = (ConstantExpression)new ListPrintVisitor().Visit(expression);
            return result.Value.ToString();
        }
    }

Expanding expressions with ExpandVisitor

    using System.Collections.Generic;
    using System.Linq;
    using System.Linq.Expressions;

    internal class ExpandVisitor : ExpressionVisitor
    {
        protected override Expression VisitBinary(BinaryExpression node)
        {
            var left = Visit(node.Left);
            var right = Visit(node.Right);

            if (node.NodeType == ExpressionType.Multiply)
            {
                // distribute: (a + b) * (c + d) => a*c + a*d + b*c + b*d
                Expression[] leftNodes = GetAddedNodes(left).ToArray();
                Expression[] rightNodes = GetAddedNodes(right).ToArray();
                var result = leftNodes
                    .SelectMany(x => rightNodes.Select(y => Expression.Multiply(x, y)))
                    .Aggregate((sum, term) => Expression.Add(sum, term));

                return result;
            }

            if (node.Left == left && node.Right == right)
                return node;

            return Expression.MakeBinary(node.NodeType, left, right, node.IsLiftedToNull, node.Method, node.Conversion);
        }

        /// <summary>
        /// Treats the <paramref name="node"/> as the sum (or difference) of one or more child nodes and returns
        /// the individual addends in the sum.
        /// </summary>
        private static IEnumerable<Expression> GetAddedNodes(Expression node)
        {
            BinaryExpression binary = node as BinaryExpression;
            if (binary != null)
            {
                switch (binary.NodeType)
                {
                case ExpressionType.Add:
                    foreach (var n in GetAddedNodes(binary.Left))
                        yield return n;
                    foreach (var n in GetAddedNodes(binary.Right))
                        yield return n;
                    yield break;

                case ExpressionType.Subtract:
                    foreach (var n in GetAddedNodes(binary.Left))
                        yield return n;
                    // a - b contributes the addends of b with their sign flipped
                    foreach (var n in GetAddedNodes(binary.Right))
                        yield return Expression.Negate(n);
                    yield break;

                default:
                    break;
                }
            }

            yield return node;
        }
    }

Taking derivatives with DerivativeVisitor

    using System;
    using System.Linq.Expressions;

    // Helper used by DerivativeVisitor and SimplifyVisitor. It was not included in the post,
    // so this is a minimal reconstruction (assumed, not the original).
    internal static class MathExpressions
    {
        public static readonly ConstantExpression Zero = Expression.Constant(0.0);
        public static readonly ConstantExpression One = Expression.Constant(1.0);

        // natural logarithm: a call to Math.Log(double)
        public static Expression Log(Expression operand)
        {
            return Expression.Call(typeof(Math).GetMethod("Log", new[] { typeof(double) }), operand);
        }
    }

    internal class DerivativeVisitor : ExpressionVisitor
    {
        private readonly ParameterExpression _parameter;
        private readonly bool _totalDerivative;

        public DerivativeVisitor(ParameterExpression parameter, bool totalDerivative)
        {
            // check the argument: the field has not been assigned yet at this point
            if (totalDerivative)
                throw new NotImplementedException();

            _parameter = parameter;
            _totalDerivative = totalDerivative;
        }

        protected override Expression VisitBinary(BinaryExpression node)
        {
            switch (node.NodeType)
            {
            case ExpressionType.Add:
            case ExpressionType.Subtract:
                // (f +/- g)' = f' +/- g'
                return Expression.MakeBinary(node.NodeType, Visit(node.Left), Visit(node.Right));

            case ExpressionType.Multiply:
                // product rule: (f g)' = f g' + f' g
                return Expression.Add(
                    Expression.Multiply(node.Left, Visit(node.Right)),
                    Expression.Multiply(Visit(node.Left), node.Right));

            case ExpressionType.Divide:
                // quotient rule: (f / g)' = (f' g - f g') / g^2 (the exponent must be a double constant)
                return Expression.Divide(
                    Expression.Subtract(
                        Expression.Multiply(Visit(node.Left), node.Right),
                        Expression.Multiply(node.Left, Visit(node.Right))),
                    Expression.Power(node.Right, Expression.Constant(2.0)));

            case ExpressionType.Power:
                if (node.Right is ConstantExpression)
                {
                    // (f^c)' = c * f' * f^(c - 1)
                    return Expression.Multiply(
                        node.Right,
                        Expression.Multiply(
                            Visit(node.Left),
                            Expression.Power(node.Left, Expression.Subtract(node.Right, MathExpressions.One))));
                }
                else if (node.Left is ConstantExpression)
                {
                    // (a^g)' = a^g * ln(a) * g'
                    return Expression.Multiply(
                        node,
                        Expression.Multiply(MathExpressions.Log(node.Left), Visit(node.Right)));
                }
                else
                {
                    // (f^g)' = f^g * (f' * g / f + g' * ln(f))
                    return Expression.Multiply(node, Expression.Add(
                        Expression.Multiply(Visit(node.Left), Expression.Divide(node.Right, node.Left)),
                        Expression.Multiply(Visit(node.Right), MathExpressions.Log(node.Left))));
                }

            default:
                throw new NotImplementedException();
            }
        }

        protected override Expression VisitConstant(ConstantExpression node)
        {
            return MathExpressions.Zero;
        }

        // calls such as Math.Log(f) are MethodCallExpression nodes, not InvocationExpression nodes
        protected override Expression VisitMethodCall(MethodCallExpression node)
        {
            if (node.Method.DeclaringType != typeof(Math) || node.Arguments.Count != 1)
                throw new NotImplementedException();

            Expression argument = node.Arguments[0];
            switch (node.Method.Name)
            {
            case "Log":
                // (ln f)' = f' / f
                return Expression.Divide(Visit(argument), argument);
            case "Log10":
                // (log10 f)' = f' / (ln(10) * f)
                return Expression.Divide(Visit(argument), Expression.Multiply(Expression.Constant(Math.Log(10)), argument));
            case "Exp":
            case "Sin":
            case "Cos":
            default:
                throw new NotImplementedException();
            }
        }

        protected override Expression VisitParameter(ParameterExpression node)
        {
            if (node == _parameter)
                return MathExpressions.One;

            return MathExpressions.Zero;
        }
    }
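Because DerivativeVisitor leaves its output unsimplified, a numeric sanity check can catch mistakes in the differentiation rules. A sketch, assuming the expression and its derivative are both wrapped as Expression<Func<double, double>> over the same parameter: it compares the compiled derivative against a central finite difference (f(x + h) − f(x − h)) / 2h at a few sample points. DerivativeCheck and Matches are hypothetical names, and the default step and tolerance are arbitrary choices.

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical helper: compares a symbolic derivative with a central finite difference.
public static class DerivativeCheck
{
    public static bool Matches(
        Expression<Func<double, double>> f,
        Expression<Func<double, double>> fPrime,
        double[] points,
        double h = 1e-5,
        double tolerance = 1e-6)
    {
        Func<double, double> fc = f.Compile();
        Func<double, double> dc = fPrime.Compile();
        foreach (double x in points)
        {
            // central difference approximation of f'(x)
            double numeric = (fc(x + h) - fc(x - h)) / (2 * h);
            // relative tolerance, with an absolute floor near zero
            if (Math.Abs(numeric - dc(x)) > tolerance * Math.Max(1.0, Math.Abs(numeric)))
                return false;
        }
        return true;
    }
}
```

To check DerivativeVisitor output for a body built from parameter x, pass Expression.Lambda<Func<double, double>>(body, x) and Expression.Lambda<Func<double, double>>(Symbolic.PartialDerivative(body, x), x).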

Simplifying expressions with SimplifyVisitor

    using System;
    using System.Linq.Expressions;

    internal class SimplifyVisitor : ExpressionVisitor
    {
        protected override Expression VisitBinary(BinaryExpression node)
        {
            var left = Visit(node.Left);
            var right = Visit(node.Right);

            ConstantExpression leftConstant = left as ConstantExpression;
            ConstantExpression rightConstant = right as ConstantExpression;
            if (leftConstant != null && rightConstant != null
                && (leftConstant.Value is double) && (rightConstant.Value is double))
            {
                // fold an operation on two double constants
                double leftValue = (double)leftConstant.Value;
                double rightValue = (double)rightConstant.Value;
                switch (node.NodeType)
                {
                case ExpressionType.Add:
                    return Expression.Constant(leftValue + rightValue);
                case ExpressionType.Subtract:
                    return Expression.Constant(leftValue - rightValue);
                case ExpressionType.Multiply:
                    return Expression.Constant(leftValue * rightValue);
                case ExpressionType.Divide:
                    return Expression.Constant(leftValue / rightValue);
                case ExpressionType.Power:
                    return Expression.Constant(Math.Pow(leftValue, rightValue));
                default:
                    throw new NotImplementedException();
                }
            }

            switch (node.NodeType)
            {
            case ExpressionType.Add:
                if (IsZero(left)) return right;
                if (IsZero(right)) return left;
                break;
            case ExpressionType.Subtract:
                if (IsZero(left)) return Expression.Negate(right);
                if (IsZero(right)) return left;
                break;
            case ExpressionType.Multiply:
                if (IsZero(left) || IsZero(right)) return MathExpressions.Zero;
                if (IsOne(left)) return right;
                if (IsOne(right)) return left;
                break;
            case ExpressionType.Divide:
                if (IsZero(right)) throw new DivideByZeroException();
                if (IsZero(left)) return MathExpressions.Zero;
                if (IsOne(right)) return left;
                break;
            case ExpressionType.Power:
                // DerivativeVisitor produces Power nodes: f^1 = f, f^0 = 1
                if (IsOne(right)) return left;
                if (IsZero(right)) return MathExpressions.One;
                break;
            default:
                throw new NotImplementedException();
            }

            return Expression.MakeBinary(node.NodeType, left, right);
        }

        protected override Expression VisitUnary(UnaryExpression node)
        {
            var operand = Visit(node.Operand);

            ConstantExpression operandConstant = operand as ConstantExpression;
            if (operandConstant != null && (operandConstant.Value is double))
            {
                double operandValue = (double)operandConstant.Value;
                switch (node.NodeType)
                {
                case ExpressionType.Negate:
                    // avoid producing -0.0
                    if (operandValue == 0.0)
                        return MathExpressions.Zero;
                    return Expression.Constant(-operandValue);
                default:
                    throw new NotImplementedException();
                }
            }

            switch (node.NodeType)
            {
            case ExpressionType.Negate:
                // -(-a) => a
                if (operand.NodeType == ExpressionType.Negate)
                    return ((UnaryExpression)operand).Operand;
                break;
            default:
                throw new NotImplementedException();
            }

            return Expression.MakeUnary(node.NodeType, operand, node.Type);
        }

        private static bool IsZero(Expression expression)
        {
            ConstantExpression constant = expression as ConstantExpression;
            return constant != null && constant.Value is double && (double)constant.Value == 0.0;
        }

        private static bool IsOne(Expression expression)
        {
            ConstantExpression constant = expression as ConstantExpression;
            return constant != null && constant.Value is double && (double)constant.Value == 1.0;
        }
    }

Formatting expressions for display with ListPrintVisitor

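    using System;
    using System.Linq.Expressions;

    internal class ListPrintVisitor : ExpressionVisitor
    {
        protected override Expression VisitBinary(BinaryExpression node)
        {
            string op = null;
            switch (node.NodeType)
            {
            case ExpressionType.Add: op = "+"; break;
            case ExpressionType.Subtract: op = "-"; break;
            case ExpressionType.Multiply: op = "*"; break;
            case ExpressionType.Divide: op = "/"; break;
            case ExpressionType.Power: op = "^"; break;
            default: throw new NotImplementedException();
            }

            var left = Visit(node.Left);
            var right = Visit(node.Right);
            string result = string.Format("({0} {1} {2})", op, ((ConstantExpression)left).Value, ((ConstantExpression)right).Value);
            return Expression.Constant(result);
        }

        // ExpandVisitor and SimplifyVisitor can produce Negate nodes; print them as (- operand)
        protected override Expression VisitUnary(UnaryExpression node)
        {
            if (node.NodeType != ExpressionType.Negate)
                throw new NotImplementedException();

            var operand = Visit(node.Operand);
            return Expression.Constant(string.Format("(- {0})", ((ConstantExpression)operand).Value));
        }

        protected override Expression VisitConstant(ConstantExpression node)
        {
            if (node.Value is string)
                return node;

            return Expression.Constant(node.Value.ToString());
        }

        protected override Expression VisitParameter(ParameterExpression node)
        {
            return Expression.Constant(node.Name);
        }
    }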

Test results

    // requires: using Microsoft.VisualStudio.TestTools.UnitTesting; using System.Linq.Expressions;
    // (place the method inside a [TestClass])
    [TestMethod]
    public void BasicSymbolicTest()
    {
        ParameterExpression x = Expression.Parameter(typeof(double), "x");

        Expression linear = Expression.Add(Expression.Constant(3.0), x);
        Assert.AreEqual("(+ 3 x)", Symbolic.ToString(linear));

        Expression quadratic = Expression.Multiply(linear, Expression.Add(Expression.Constant(2.0), x));
        Assert.AreEqual("(* (+ 3 x) (+ 2 x))", Symbolic.ToString(quadratic));

        Expression expanded = Symbolic.Expand(quadratic);
        Assert.AreEqual("(+ (+ (+ (* 3 2) (* 3 x)) (* x 2)) (* x x))", Symbolic.ToString(expanded));
        Assert.AreEqual("(+ (+ (+ 6 (* 3 x)) (* x 2)) (* x x))", Symbolic.ToString(Symbolic.Simplify(expanded)));

        Expression derivative = Symbolic.PartialDerivative(expanded, x);
        Assert.AreEqual("(+ (+ (+ (+ (* 3 0) (* 0 2)) (+ (* 3 1) (* 0 x))) (+ (* x 0) (* 1 2))) (+ (* x 1) (* 1 x)))", Symbolic.ToString(derivative));

        Expression simplified = Symbolic.Simplify(derivative);
        Assert.AreEqual("(+ 5 (+ x x))", Symbolic.ToString(simplified));
    }