# Square root of a complex number

## Numerical behaviour 

In [None]:
import math

import numpy as np
import scipy as sc
import sympy as sp

import graphviz

In [None]:
np.sqrt(-1 + 0j)

In [None]:
np.sqrt(-1)

In [None]:
math.sqrt(-1)

In [None]:
sp.sqrt(4j)

In [None]:
np.emath.sqrt(-1)

In [None]:
type(np.emath.sqrt(-1))

In [None]:
type(sp.sqrt(4j))

In [None]:
sp.print_tree(sp.sqrt(4j), assumptions=False)

In [None]:
complex(0, 4)

In [None]:
sp.sympify(4j)

In [None]:
sp.Integer(4) * sp.I

In [None]:
4 * sp.I

# Getting control over lambdification

## Example of Custom Printing Method

In [None]:
from sympy import Integer, Mod, Symbol, print_latex

In [None]:
# Always use printer._print()
class ModOp(Mod):
    """``Mod`` subclass whose LaTeX form is an operatorname-style ``Mod(a, b)``."""

    def _latex(self, printer):
        # Each argument is rendered through the printer's own _print method.
        dividend, divisor = map(printer._print, self.args)
        return r"\operatorname{Mod}{\left(%s, %s\right)}" % (dividend, divisor)

In [None]:
# Print the LaTeX of the stock Mod and of the ModOp override for the same arguments.
x, m = Symbol("x"), Symbol("m")
for mod_cls in (Mod, ModOp):
    print_latex(mod_cls(x, m))

In [None]:
Mod(x, m)

In [None]:
ModOp(x, m)

## Custom `SymPy` expression class

In [None]:
class MyExpr(sp.Expr):
    """Single-argument expression that stays unevaluated until asked.

    ``eval()`` returns the square of the wrapped argument, ``doit()`` applies
    that evaluation (recursively by default), and the LaTeX form is ``f``
    applied to the argument.
    """

    def __new__(cls, var, **kwargs):
        """Build the node; ``kwargs`` are forwarded unchanged to ``sp.Expr.__new__``."""
        # Convert to a SymPy expression if not already one.
        var = sp.sympify(var)
        return sp.Expr.__new__(cls, var, **kwargs)

    def eval(self, **hints):
        """Return the wrapped argument squared; ``hints`` are accepted but unused."""
        return self.args[0] ** 2

    def doit(self, **hints):
        """Evaluate this node.

        With ``deep=True`` (the default) every argument that is a ``sp.Basic``
        is ``doit()``-ed first, the node is rebuilt from the results via
        ``self.func`` and then evaluated. With ``deep=False`` the node is
        evaluated as it stands.
        """
        if hints.get("deep", True):
            terms = [
                term.doit(**hints) if isinstance(term, sp.Basic) else term
                for term in self.args
            ]
            return self.func(*terms).eval(**hints)
        return self.eval(**hints)

    def _latex(self, printer):
        """Return the LaTeX form of this node using the given printer."""
        # Sub-expressions must go through printer._print, not printer.doprint:
        # doprint is the top-level entry point and applies the printer's
        # ``mode`` wrapping (e.g. inline/equation delimiters), which would end
        # up nested inside the output.
        return r"f\left(" + printer._print(self.args[0]) + r"\right)"

The `__new__` method
and the
`_latex` method
are essential here: they construct and print the custom SymPy expression class, respectively.
The overridden `doit()` method provides the custom evaluation.

In [None]:
x, y = sp.symbols("x,y")
expr = MyExpr(x * y)
expr

In [None]:
sp.print_tree(expr, assumptions=False)

In [None]:
dot = sp.dotprint(expr)
graphviz.Source(dot)

In [None]:
# Show the unevaluated expression, its evaluated form, and its LaTeX string.
for label, value in (
    ("Original expression:", expr),
    ("Doit output:", expr.doit()),
    ("LaTeX representation:", sp.latex(expr)),
):
    print(label, value)

In [None]:
expr.doit()

In [None]:
expr.eval()

In [None]:
sp.latex(expr)

In [None]:
import inspect

print(inspect.getsource(sp.Expr.doit))

In [None]:
expr2 = MyExpr(MyExpr(x * y))
expr2

In [None]:
dot = sp.dotprint(expr2)
graphviz.Source(dot)

In [None]:
expr2.doit()

In [None]:
expr2.doit().doit() # unnecessary for the original doit()

In [None]:
expr2.eval()

In [None]:
sp.expand(expr2.eval())

In [None]:
expr3 = MyExpr(MyExpr(MyExpr(x * y)))
expr3

In [None]:
dot = sp.dotprint(expr3)
graphviz.Source(dot)

In [None]:
expr3.doit()

In [None]:
expr3.eval()

In [None]:
n = sp.Symbol("n")
expr1 = sp.Sum(MyExpr(x) ** n, (n, 1, 3))
expr1

In [None]:
dot = sp.dotprint(expr1)
graphviz.Source(dot)

In [None]:
expr1.doit()

In [None]:
MyExpr(MyExpr(x)).doit()

In [None]:
MyExpr(MyExpr(x)).doit(deep=False)

In [None]:
sp.print_tree(expr1.doit(), assumptions=False)

## Custom Printer

In [None]:
print(sp.latex(expr1))

In [None]:
from sympy.printing.latex import LatexPrinter

printer = LatexPrinter()
printer.doprint(expr1)

In [None]:
class MyLatexPrinter(LatexPrinter):
    """LaTeX printer that renders ``MyExpr`` nodes as ``g`` applied to the argument."""

    # NOTE(review): presumably renamed so that MyExpr._latex is not picked up
    # and _print_MyExpr below is used instead -- confirm.
    printmethod = "_Latex1"

    def _print_MyExpr(self, expr) -> str:
        inner = self._print(expr.args[0])
        return rf"g\left({inner}\right)"
 

printer = MyLatexPrinter()
printer.doprint(expr1)
 

## Lambdification