Update 'algorithm/optimization/gradient.py'

This commit is contained in:
TerenceLiu 2023-02-19 12:50:29 +08:00
parent dba617b874
commit c48bfd2e8c
1 changed files with 4 additions and 1 deletions

View File

@ -364,7 +364,10 @@ class gd2d_compete(object):
with self.plot_output:
clear_output(wait=True)
x1, x2 =symbols("x1 x2")
xx1, xx2 = np.arange(0, 5, 0.25), np.arange(0, 5, 0.25)
if self.wg_expr.value == "(1 - 8 * x1 + 7 * x1**2 - (7/3) * x1**3 + (1/4) * x1**4) * x2**2 * E**(-x2)":
xx1, xx2 = np.arange(0, 5, 0.25), np.arange(0, 5, 0.25)
else:
xx1, xx2 = np.arange(-5, 5, 0.25), np.arange(-5, 5, 0.25)
xx1, xx2 = np.meshgrid(xx1, xx2)
func = lambdify((x1, x2), sympify(self.wg_expr.value), "numpy")
fx = func(xx1, xx2)