Total derivatives in sympy

This is a quick one. I'd like to use sympy to derive the total derivative of an expression symbolically, not just a partial derivative. sympy has straightforward support for partial derivatives, but I was confused about how to take total derivatives.

For example, I'd like to say

difftotal(x_dot**2 + g*cos(x), t)

and get back

-g*x_dot*sin(x) + 2*x_ddot*x_dot

The difficulty is specifying which symbols are actually functions of t and which are constants (g, in this case). You also need to say that dx/dt is the symbol x_dot (and that d(x_dot)/dt is x_ddot), even though x_dot already appears in the original expression.
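
A small sketch of why this matters, assuming a current Python 3 sympy: differentiating with plain symbols treats x and x_dot as independent of t, so the result is just zero. Turning them into undefined functions of t (via sympy's Function) makes the chain rule apply and produces Derivative(x(t), t) terms. The variable names xf and xdf are only for illustration.

```python
import sympy as sp

t, g = sp.symbols("t g")
x, x_dot = sp.symbols("x x_dot")

expr = x_dot**2 + g*sp.cos(x)

# With plain symbols, sympy assumes nothing depends on t,
# so the derivative with respect to t is zero.
plain = sp.diff(expr, t)

# Declaring x and x_dot as undefined functions of t lets the
# chain rule produce Derivative(x(t), t) terms.
xf = sp.Function("x")(t)
xdf = sp.Function("x_dot")(t)
chained = sp.diff(expr.subs({x: xf, x_dot: xdf}), t)

print(plain)
print(chained)
```

The remaining work is to swap those Derivative terms back for the symbols x_dot and x_ddot.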

Here's a function to do just that:

from sympy import *

def difftotal(expr, diffby, diffmap):
    """Take the total derivative with respect to a variable.

    Example:

        theta, t, theta_dot = symbols("theta t theta_dot")
        difftotal(cos(theta), t, {theta: theta_dot})

    returns

        -theta_dot*sin(theta)
    """
    # Map each symbol in diffmap to an undefined function of diffby,
    # e.g. theta -> theta(t) (sympy Symbols are not callable)
    fnmap = {s: Function(s.name)(diffby) for s in diffmap}
    # Replace all symbols in the diffmap by their functional form
    fnexpr = expr.subs(fnmap)
    # Do the differentiation
    diffexpr = diff(fnexpr, diffby)
    # Replace the Derivatives with the variables in diffmap
    derivmap = {Derivative(fnmap[v], diffby): dv
                for v, dv in diffmap.items()}
    finaldiff = diffexpr.subs(derivmap)
    # Replace the functional forms with their original symbols
    return finaldiff.subs({f: s for s, f in fnmap.items()})
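
For comparison, here is a sketch of a different route to the same answer, not the substitution trick used in the function: the multivariable chain rule applied directly, d(expr)/dt = ∂expr/∂t + Σ (∂expr/∂s)·(ds/dt) over the symbols s in diffmap. It assumes the keys of diffmap are the only quantities that depend on t; every other symbol is a constant. The name difftotal_chain is hypothetical.

```python
import sympy as sp

def difftotal_chain(expr, diffby, diffmap):
    """Hypothetical alternative: total derivative via the chain rule.

    Assumes every quantity that varies with diffby is a key of
    diffmap; all other free symbols are treated as constants.
    """
    # Explicit dependence on diffby itself (zero if there is none)
    total = sp.diff(expr, diffby)
    # Add partial(expr, s) * ds/d(diffby) for each time-dependent symbol
    for s, ds in diffmap.items():
        total += sp.diff(expr, s) * ds
    return total

x_dot, x_ddot, t, x, g = sp.symbols("x_dot x_ddot t x g")
result = difftotal_chain(x_dot**2 + g*sp.cos(x), t, {x: x_dot, x_dot: x_ddot})
print(result)
```

This avoids creating functions of t altogether, at the cost of silently giving wrong answers if some time-dependent symbol is left out of diffmap.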

Now you can say:

>>> from sympy import *
>>> x_dot, x_ddot, t, x, g = symbols("x_dot x_ddot t x g")
>>> diffresult = difftotal(x_dot**2 + g*cos(x), t, {x: x_dot, x_dot: x_ddot})
>>> print(diffresult)
-g*x_dot*sin(x) + 2*x_ddot*x_dot

And to clean up:

>>> simplify(diffresult)
x_dot*(-g*sin(x) + 2*x_ddot)

A gotcha for new sympy users: the sin and cos above come from sympy (via from sympy import *), not from Python's math module. math.sin and math.cos only accept numbers, so they fail on sympy symbols.
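
A quick illustration of the gotcha, assuming Python 3: math.cos tries to convert its argument to a float, which a bare Symbol cannot do, so it raises a TypeError, while sympy.cos builds a symbolic expression. Importing sympy under a short alias, as here, is one way to avoid mixing up the two; that is a suggestion, not something the original code does.

```python
import math
import sympy as sp

x = sp.Symbol("x")

# sympy.cos returns a symbolic expression in x
print(sp.cos(x))

# math.cos needs a number; a bare Symbol cannot be converted to float
failed = False
try:
    math.cos(x)
except TypeError as err:
    failed = True
    print("math.cos failed:", err)
```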
