update
This commit is contained in:
parent
30de2d2269
commit
e995851232
|
@ -16,6 +16,7 @@ import matplotlib.pyplot as plt
|
||||||
import plotly.io as pio
|
import plotly.io as pio
|
||||||
import plotly.graph_objects as go
|
import plotly.graph_objects as go
|
||||||
import plotly.figure_factory as ff
|
import plotly.figure_factory as ff
|
||||||
|
from plotly.subplots import make_subplots
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
|
@ -82,7 +83,7 @@ class gd_1d(object):
|
||||||
gradient = df(xn)
|
gradient = df(xn)
|
||||||
self.xn_list.append(xn)
|
self.xn_list.append(xn)
|
||||||
self.df_list.append(gradient)
|
self.df_list.append(gradient)
|
||||||
if abs(gradient < self.wg_epsilon.value):
|
if (abs(gradient) < self.wg_epsilon.value):
|
||||||
clear_output(wait=True)
|
clear_output(wait=True)
|
||||||
print("Found solution of {} after".format(expr), n, "iterations")
|
print("Found solution of {} after".format(expr), n, "iterations")
|
||||||
print("x* = {}".format(xn))
|
print("x* = {}".format(xn))
|
||||||
|
@ -108,14 +109,14 @@ class gd_1d(object):
|
||||||
fig.add_scatter(x=xx1, y=fx)
|
fig.add_scatter(x=xx1, y=fx)
|
||||||
frames = []
|
frames = []
|
||||||
frames.append({'data':copy.deepcopy(fig['data']),'name':f'frame{0}'})
|
frames.append({'data':copy.deepcopy(fig['data']),'name':f'frame{0}'})
|
||||||
fig.add_traces(go.Scatter(x=None, y=None, mode="lines + markers", line={"color":"#de1032", "width":5}))
|
fig.add_traces(go.Scatter(x=None, y=None, mode="lines + markers", line={"color":"#de1032", "width":1, 'dash': 'dash'}))
|
||||||
frames = [go.Frame(data= [go.Scatter(x=np.array(self.xn_list)[:k], y=f_xn)],traces= [1],name=f'frame{k+2}')for k in range(len(f_xn))]
|
frames = [go.Frame(data= [go.Scatter(x=np.array(self.xn_list)[:k], y=f_xn)],traces= [1],name=f'frame{k+2}')for k in range(len(f_xn))]
|
||||||
fig.update(frames=frames)
|
fig.update(frames=frames)
|
||||||
fig.update_layout(updatemenus=[dict(type="buttons",buttons=[dict(label="Play",method="animate",args=[None])])])
|
fig.update_layout(height=800, updatemenus=[dict(type="buttons",buttons=[dict(label="Play",method="animate",args=[None, dict(fromcurrent=True, transition=dict(duration=0), frame=dict(redraw=True, duration=1000))])])])
|
||||||
fig.show()
|
fig.show()
|
||||||
|
|
||||||
|
|
||||||
class gd_2d(object):
|
class gd2d(object):
|
||||||
def __init__(self, environ:str="jupyterlab", type="default"):
|
def __init__(self, environ:str="jupyterlab", type="default"):
|
||||||
if type == "default":
|
if type == "default":
|
||||||
self.initialization_default(environ=environ)
|
self.initialization_default(environ=environ)
|
||||||
|
@ -217,7 +218,6 @@ class gd_2d(object):
|
||||||
proj_z = lambda x, y, z: z
|
proj_z = lambda x, y, z: z
|
||||||
colorsurfz = proj_z(xx1, xx2, fx)
|
colorsurfz = proj_z(xx1, xx2, fx)
|
||||||
|
|
||||||
from plotly.subplots import make_subplots
|
|
||||||
fig = make_subplots(rows=1, cols=2, specs=[[{'type': 'surface'}, {'type': 'surface'}]])
|
fig = make_subplots(rows=1, cols=2, specs=[[{'type': 'surface'}, {'type': 'surface'}]])
|
||||||
fig.add_trace(go.Surface(contours = {"x": {"show": True}, "y":{"show": True}, "z":{"show": True}}, x=xx1, y=xx2, z=fx), row=1, col=1)
|
fig.add_trace(go.Surface(contours = {"x": {"show": True}, "y":{"show": True}, "z":{"show": True}}, x=xx1, y=xx2, z=fx), row=1, col=1)
|
||||||
fig.add_trace(go.Scatter3d(x=None, y=None, z=None), row=1, col=1)
|
fig.add_trace(go.Scatter3d(x=None, y=None, z=None), row=1, col=1)
|
||||||
|
@ -226,11 +226,11 @@ class gd_2d(object):
|
||||||
fig.add_trace(go.Scatter3d(x=None, y=None, z=None), row=1, col=2)
|
fig.add_trace(go.Scatter3d(x=None, y=None, z=None), row=1, col=2)
|
||||||
fig.add_trace(go.Scatter3d(x=None, y=None, z=None), row=1, col=2)
|
fig.add_trace(go.Scatter3d(x=None, y=None, z=None), row=1, col=2)
|
||||||
frames = [go.Frame(data=[go.Surface(visible=True, showscale=False, opacity=0.8),
|
frames = [go.Frame(data=[go.Surface(visible=True, showscale=False, opacity=0.8),
|
||||||
go.Scatter3d(x=np.array(self.xn_list)[:k,0], y=np.array(self.xn_list)[:k,1], z=f_xn),
|
go.Scatter3d(x=np.array(self.xn_list)[:k,0], y=np.array(self.xn_list)[:k,1], z=f_xn, line={"color":"#10dedb", "width":3, 'dash': 'dash'}),
|
||||||
go.Surface(visible=False, x=xx1_tangent, y=xx2_tangent, z=z[k]),
|
go.Surface(visible=False, x=xx1_tangent, y=xx2_tangent, z=z[k]),
|
||||||
go.Surface(visible=True, showscale=False, opacity=0.8),
|
go.Surface(visible=True, showscale=False, opacity=0.8),
|
||||||
go.Scatter3d(x=np.array(self.xn_list)[:k, 0], y=np.array(self.xn_list)[:k, 1], z=f_xn),
|
go.Scatter3d(x=np.array(self.xn_list)[:k, 0], y=np.array(self.xn_list)[:k, 1], z=f_xn, line={"color":"#10dedb", "width":3, 'dash': 'dash'}),
|
||||||
go.Scatter3d(x=np.array(self.xn_list)[:k, 0].flatten(), y=np.array(self.xn_list)[:k, 1].flatten(), z=z_offset.flatten())],
|
go.Scatter3d(x=np.array(self.xn_list)[:k, 0].flatten(), y=np.array(self.xn_list)[:k, 1].flatten(), z=z_offset.flatten(), line={"color":"#58de10", "width":3, 'dash': 'dash'})],
|
||||||
traces=[0, 1, 2, 3, 4, 5]) for k in range(len(f_xn))]
|
traces=[0, 1, 2, 3, 4, 5]) for k in range(len(f_xn))]
|
||||||
fig.frames = frames
|
fig.frames = frames
|
||||||
self.fig_frames = frames
|
self.fig_frames = frames
|
||||||
|
@ -284,3 +284,83 @@ class gd_2d(object):
|
||||||
buttons=list([
|
buttons=list([
|
||||||
dict(args=[{"visible":["True", "True"]}], label="Gradient", method="update")]))], height=800)
|
dict(args=[{"visible":["True", "True"]}], label="Gradient", method="update")]))], height=800)
|
||||||
fig.show()
|
fig.show()
|
||||||
|
|
||||||
|
|
||||||
|
class gd2d_compete(object):
|
||||||
|
def __init__(self, environ:str="jupyterlab"):
|
||||||
|
self.initialization(environ=environ)
|
||||||
|
|
||||||
|
|
||||||
|
def initialization(self, environ):
|
||||||
|
pio.renderers.default = environ # 'notebook' or 'colab' or 'jupyterlab'
|
||||||
|
self.timer = 0
|
||||||
|
self.wg_expr = widgets.Dropdown(options=[("(1 - 8 * x1 + 7 * x1**2 - (7/3) * x1**3 + (1/4) * x1**4) * x2**2 * E**(-x2)", "(1 - 8 * x1 + 7 * x1**2 - (7/3) * x1**3 + (1/4) * x1**4) * x2**2 * E**(-x2)"), ("(sin(x1) - 2) ** 2 + (sin(x2) - 2) ** 2", "(sin(x1) - 2) ** 2 + (sin(x2) - 2) ** 2")], value="(1 - 8 * x1 + 7 * x1**2 - (7/3) * x1**3 + (1/4) * x1**4) * x2**2 * E**(-x2)", descrption="Expression")
|
||||||
|
self.wg_x0 = widgets.Text(value="0,2", description="Init point:")
|
||||||
|
self.wg_lr = widgets.FloatText(value="0.1", description="step size:")
|
||||||
|
self.wg_direction_p0 = widgets.Text(value="0.5,1", description="Direction (a1)")
|
||||||
|
self.wg_direction_p1 = widgets.Text(value="2.5,1", description="Direction (a2)")
|
||||||
|
# need learning rate
|
||||||
|
self.button_compute = widgets.Button(description="Compute")
|
||||||
|
self.button_plot = widgets.Button(description="Plot")
|
||||||
|
self.compute_output = widgets.Output()
|
||||||
|
self.plot_output =widgets.Output()
|
||||||
|
|
||||||
|
self.params_lvbox = widgets.VBox([self.wg_x0, self.wg_lr])
|
||||||
|
self.params_rvbox = widgets.VBox([self.wg_direction_p0, self.wg_direction_p1])
|
||||||
|
self.exp_box = widgets.HBox([self.wg_expr])
|
||||||
|
self.params_box = widgets.HBox([self.params_lvbox, self.params_rvbox], description="Parameters")
|
||||||
|
self.button_box = widgets.HBox([self.button_compute, self.button_plot])
|
||||||
|
self.config = widgets.VBox([self.exp_box, self.params_box, self.button_box])
|
||||||
|
display(self.config)
|
||||||
|
|
||||||
|
self.button_compute.on_click(self.compute)
|
||||||
|
display(self.compute_output)
|
||||||
|
self.button_plot.on_click(self.plot)
|
||||||
|
display(self.plot_output)
|
||||||
|
|
||||||
|
def compute(self, *args):
|
||||||
|
with self.compute_output:
|
||||||
|
expr = sympify(self.wg_expr.value)
|
||||||
|
if self.timer == 0:
|
||||||
|
self.xn_p0_list, self.xn_p1_list = [], []
|
||||||
|
x0 = np.array(self.wg_x0.value.split(","), dtype=float)
|
||||||
|
self.xn_p0_list.append(x0)
|
||||||
|
self.xn_p1_list.append(x0)
|
||||||
|
direction_p0 = np.array(self.wg_direction_p0.value.split(","), dtype=float)
|
||||||
|
direction_p1 = np.array(self.wg_direction_p1.value.split(","), dtype=float)
|
||||||
|
self.timer = self.timer + 1
|
||||||
|
# calcualte next point position
|
||||||
|
x0_p0 = self.xn_p0_list[self.timer-1] + self.wg_lr.value * direction_p0
|
||||||
|
x0_p1 = self.xn_p1_list[self.timer-1] + self.wg_lr.value * direction_p1
|
||||||
|
self.xn_p0_list.append(x0_p0)
|
||||||
|
self.xn_p1_list.append(x0_p1)
|
||||||
|
clear_output(wait=True)
|
||||||
|
print("a1({}): {}, a2({}): {}".format(self.timer, x0_p0, self.timer, x0_p1))
|
||||||
|
|
||||||
|
def plot(self, *args):
|
||||||
|
with self.plot_output:
|
||||||
|
clear_output(wait=True)
|
||||||
|
x1, x2 =symbols("x1 x2")
|
||||||
|
xx1, xx2 = np.arange(0, 5, 0.25), np.arange(0, 5, 0.25)
|
||||||
|
xx1, xx2 = np.meshgrid(xx1, xx2)
|
||||||
|
func = lambdify((x1, x2), sympify(self.wg_expr.value), "numpy")
|
||||||
|
fx = func(xx1, xx2)
|
||||||
|
fx_p0 = func(np.array(self.xn_p0_list)[:, 0], np.array(self.xn_p0_list)[:, 1])
|
||||||
|
fx_p1 = func(np.array(self.xn_p1_list)[:, 0], np.array(self.xn_p1_list)[:, 1])
|
||||||
|
|
||||||
|
#TODO: compute gradient
|
||||||
|
|
||||||
|
fig = make_subplots(rows=1, cols=1, specs=[[{'type': 'surface'}]])
|
||||||
|
fig.add_trace(go.Surface(x=xx1, y=xx2, z=fx), row=1, col=1)
|
||||||
|
fig.add_trace(go.Scatter3d(x=np.array(self.xn_p0_list)[:, 0], y=np.array(self.xn_p0_list)[:, 1], z=fx_p0,
|
||||||
|
name="candidate 1", mode="lines+markers", marker=dict(color="green")), row=1, col=1)
|
||||||
|
fig.add_trace(go.Scatter3d(x=np.array(self.xn_p1_list)[:, 0], y=np.array(self.xn_p1_list)[:, 1], z=fx_p1,
|
||||||
|
name="candidate 2", mode="lines+markers", marker=dict(color="blue")), row=1, col=1)
|
||||||
|
frames = [go.Frame(data = [go.Surface(visible=True, showscale=False, opacity=0.8),
|
||||||
|
go.Scatter3d(x=np.array(self.xn_p0_list)[:self.timer, 0], y=np.array(self.xn_p0_list)[:self.timer, 1], z=fx_p0),
|
||||||
|
go.Scatter3d(x=np.array(self.xn_p1_list)[:self.timer, 0], y=np.array(self.xn_p1_list)[:self.timer, 1], z=fx_p1)],
|
||||||
|
traces=[0,1,2])]
|
||||||
|
fig.frames = frames
|
||||||
|
fig.update_layout(scene_aspectmode='manual', scene_aspectratio=dict(x=0, y=0, z=0), height=800)
|
||||||
|
fig.update_layout(legend=dict(yanchor="auto", y=0.9, xanchor="left", x=0.4))
|
||||||
|
fig.show()
|
Loading…
Reference in New Issue