Square root of a complex number

Numerical behaviour

import math

import numpy as np
import scipy as sc
import sympy as sp

import graphviz
np.sqrt(-1 + 0j)
1j
np.sqrt(-1)
/tmp/ipykernel_1929/3438155168.py:1: RuntimeWarning: invalid value encountered in sqrt
  np.sqrt(-1)
nan
math.sqrt(-1)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[4], line 1
----> 1 math.sqrt(-1)

ValueError: math domain error
sp.sqrt(4j)
\[\displaystyle 2.0 \sqrt{i}\]
np.emath.sqrt(-1)
1j
type(np.emath.sqrt(-1))
numpy.complex128
type(sp.sqrt(4j))
sympy.core.mul.Mul
sp.print_tree(sp.sqrt(4j), assumptions=False)
Mul: 2.0*sqrt(I)
+-Float: 2.00000000000000
+-Pow: sqrt(I)
  +-ImaginaryUnit: I
  +-Half: 1/2
complex(0, 4)
4j
sp.sympify(4j)
\[\displaystyle 4.0 i\]
sp.Integer(4) * sp.I
\[\displaystyle 4 i\]
4 * sp.I
\[\displaystyle 4 i\]

Getting control over lambdification

Example of a custom printing method

from sympy import Integer, Mod, Symbol, print_latex
# Always use printer._print() (never printer.doprint()) to render
# sub-expressions, so the active printer's settings carry through to them.
class ModOp(Mod):
    r"""``Mod`` subclass whose LaTeX form is ``\operatorname{Mod}(a, b)``.

    The plain ``Mod`` prints as ``a \bmod b``; this class only overrides the
    LaTeX rendering and otherwise behaves exactly like ``Mod``.
    """

    def _latex(self, printer):
        dividend, divisor = (printer._print(arg) for arg in self.args)
        return r"\operatorname{Mod}{\left(%s, %s\right)}" % (dividend, divisor)
x = Symbol("x")
m = Symbol("m")
print_latex(Mod(x, m))
print_latex(ModOp(x, m))
x \bmod m
\operatorname{Mod}{\left(x, m\right)}
Mod(x, m)
\[\displaystyle x \bmod m\]
ModOp(x, m)
\[\displaystyle \operatorname{Mod}{\left(x, m\right)}\]

Custom SymPy expression class

class MyExpr(sp.Expr):
    """Toy unevaluated expression ``f(var)`` that evaluates to ``var**2``.

    The object stays symbolic (printing as ``f(...)`` in LaTeX) until
    ``doit()`` or ``eval()`` is called on it.
    """

    def __new__(cls, var, **kwargs):
        # Convert plain Python values (int, float, complex, ...) into SymPy
        # objects so that ``self.args`` only holds SymPy expressions.
        var = sp.sympify(var)
        return sp.Expr.__new__(cls, var, **kwargs)

    def eval(self, **hints):
        """Return the square of the single argument, without evaluating it first.

        ``hints`` is accepted for symmetry with ``doit`` and is ignored.
        Nested ``MyExpr`` arguments are therefore left symbolic, e.g.
        ``MyExpr(MyExpr(x)).eval()`` gives ``f(x)**2``.
        """
        return self.args[0] ** 2

    def doit(self, **hints):
        """Evaluate ``f(var)`` to ``var**2``.

        With ``deep=True`` (the default) the argument is evaluated first, so
        nested ``MyExpr`` objects collapse completely, e.g.
        ``MyExpr(MyExpr(x)).doit()`` gives ``x**4``. With ``deep=False`` only
        the outermost level is evaluated, giving ``f(x)**2``.
        """
        if hints.get("deep", True):
            # Same recursion as ``sympy.Expr.doit``, followed by our own
            # evaluation step on the rebuilt object.
            terms = [
                term.doit(**hints) if isinstance(term, sp.Basic) else term
                for term in self.args
            ]
            return self.func(*terms).eval(**hints)
        return self.eval(**hints)

    def _latex(self, printer):
        # Use ``printer._print`` rather than ``printer.doprint``. ``doprint`` is
        # the top-level entry point and re-applies settings such as
        # ``mode="inline"`` (wrapping in ``$...$``) to the argument, which
        # yields broken LaTeX for nested expressions.
        return r"f\left(" + printer._print(self.args[0]) + r"\right)"

The `__new__` and `_latex` methods are the essential pieces for making a custom SymPy expression class: `__new__` constructs it and `_latex` controls how it prints. The overridden `doit()` method supplies the custom evaluation step.

x, y = sp.symbols("x,y")
expr = MyExpr(x * y)
expr
\[\displaystyle f\left(x y\right)\]
sp.print_tree(expr, assumptions=False)
MyExpr: MyExpr(x*y)
+-Mul: x*y
  +-Symbol: x
  +-Symbol: y
dot = sp.dotprint(expr)
graphviz.Source(dot)
_images/1eccfd552fbb90a6b51569d292c21924f14c86d8f2040903ab076a432ca1a368.svg
print("Original expression:", expr)
print("Doit output:", expr.doit())
print("LaTeX representation:", sp.latex(expr))
Original expression: MyExpr(x*y)
Doit output: x**2*y**2
LaTeX representation: f\left(x y\right)
expr.doit()
\[\displaystyle x^{2} y^{2}\]
expr.eval()
\[\displaystyle x^{2} y^{2}\]
sp.latex(expr)
'f\\left(x y\\right)'
import inspect

print(inspect.getsource(sp.Expr.doit))
    def doit(self, **hints):
        """Evaluate objects that are not evaluated by default like limits,
        integrals, sums and products. All objects of this kind will be
        evaluated recursively, unless some species were excluded via 'hints'
        or unless the 'deep' hint was set to 'False'.

        >>> from sympy import Integral
        >>> from sympy.abc import x

        >>> 2*Integral(x, x)
        2*Integral(x, x)

        >>> (2*Integral(x, x)).doit()
        x**2

        >>> (2*Integral(x, x)).doit(deep=False)
        2*Integral(x, x)

        """
        if hints.get('deep', True):
            terms = [term.doit(**hints) if isinstance(term, Basic) else term
                                         for term in self.args]
            return self.func(*terms)
        else:
            return self
expr2 = MyExpr(MyExpr(x * y))
expr2
\[\displaystyle f\left(f\left(x y\right)\right)\]
dot = sp.dotprint(expr2)
graphviz.Source(dot)
_images/55363a75d0d25691f0d0c8a8a36c44f1bc287a1509539c99f4552f7aa0bcea41.svg
expr2.doit()
\[\displaystyle x^{4} y^{4}\]
expr2.doit().doit()  # unnecessary for the original doit()
\[\displaystyle x^{4} y^{4}\]
expr2.eval()
\[\displaystyle f\left(x y\right)^{2}\]
sp.expand(expr2.eval())
\[\displaystyle f\left(x y\right)^{2}\]
expr3 = MyExpr(MyExpr(MyExpr(x * y)))
expr3
\[\displaystyle f\left(f\left(f\left(x y\right)\right)\right)\]
dot = sp.dotprint(expr3)
graphviz.Source(dot)
_images/0cca62d8ad23e4e9937df9511baac736343f8c18514b75ec91f7fe9a16e9538c.svg
expr3.doit()
\[\displaystyle x^{8} y^{8}\]
expr3.eval()
\[\displaystyle f\left(f\left(x y\right)\right)^{2}\]
n = sp.Symbol("n")
expr1 = sp.Sum(MyExpr(x) ** n, (n, 1, 3))
expr1
\[\displaystyle \sum_{n=1}^{3} f\left(x\right)^{n}\]
dot = sp.dotprint(expr1)
graphviz.Source(dot)
_images/92bd32f9c2d47f1511eff28493a112b4c30d5e95b1421ed174bb6f29ed66437a.svg
expr1.doit()
\[\displaystyle x^{6} + x^{4} + x^{2}\]
MyExpr(MyExpr(x)).doit()
\[\displaystyle x^{4}\]
MyExpr(MyExpr(x)).doit(deep=False)
\[\displaystyle f\left(x\right)^{2}\]
sp.print_tree(expr1.doit(), assumptions=False)
Add: x**6 + x**4 + x**2
+-Pow: x**2
| +-Symbol: x
| +-Integer: 2
+-Pow: x**4
| +-Symbol: x
| +-Integer: 4
+-Pow: x**6
  +-Symbol: x
  +-Integer: 6

Custom printer

print(sp.latex(expr1))
\sum_{n=1}^{3} f\left(x\right)^{n}
from sympy.printing.latex import LatexPrinter

printer = LatexPrinter()
printer.doprint(expr1)
'\\sum_{n=1}^{3} f\\left(x\\right)^{n}'
class MyLatexPrinter(LatexPrinter):
    """LaTeX printer that renders ``MyExpr`` as ``g(...)`` instead of ``f(...)``."""

    # A SymPy printer first lets an object print itself through the method
    # named by ``printmethod`` (``_latex`` for LatexPrinter), and only then
    # falls back to ``_print_<ClassName>``. MyExpr defines ``_latex``, so the
    # hook is redirected to a name that no class here defines; that makes
    # ``_print_MyExpr`` below take effect. Side effect: other classes' ``_latex``
    # methods (e.g. ModOp) are ignored by this printer too.
    printmethod = '_Latex1'

    def _print_MyExpr(self, expr) -> str:
        inner = self._print(expr.args[0])
        return "".join((r"g\left(", inner, r"\right)"))
        

printer = MyLatexPrinter()
printer.doprint(expr1)
    
'\\sum_{n=1}^{3} g\\left(x\\right)^{n}'