update
This commit is contained in:
parent
2fb99563ee
commit
0bccf62d54
|
@ -0,0 +1 @@
|
||||||
|
test/
|
|
@ -0,0 +1,75 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "af1d5646-4b13-4039-9f76-8042bc9dbda3",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from optimization.gd import gradient_descent_1d, gradient_descent_2d"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "0a5c51d2-8b18-4143-b6f1-73a909ccb623",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"gd2 = gradient_descent_2d()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "f254549d-d9b0-43ac-a502-e65b7c462d05",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"plotly.graph_objs._scatter3d.Scatter3d"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import plotly.graph_objects as go\n",
|
||||||
|
"go.Scatter3d"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "8c0f6766-4f41-419b-94a0-916a379ebcd1",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.4"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
Binary file not shown.
|
@ -0,0 +1,149 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "af1d5646-4b13-4039-9f76-8042bc9dbda3",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from optimization.gd import gradient_descent_1d, gradient_descent_2d"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "0a5c51d2-8b18-4143-b6f1-73a909ccb623",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "923a9d77eecb4d78ade9e3113cf636c0",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"VBox(children=(HBox(children=(VBox(children=(Text(value='x**3 - x**(1/2)', description='Expression:', style=De…"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "c2252623a38940f0a80471ec28e491d3",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Output()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "778d4e34c4ae4b0d8f2f1b4e0d3c1173",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Output()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"gd1 = gradient_descent_1d()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "15c6e757-cde3-422b-be7e-3f55b7752142",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "298df33c058b49e6927e582d532bb87b",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"VBox(children=(HBox(children=(VBox(children=(Text(value='(sin(x1) - 2) ** 2 + (sin(x2) - 2) ** 2', description…"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "47315c498b9a46fd9948fd7a53d3e80a",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Output()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "0cab52875285496a989f685de24515cb",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Output()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"gd2 = gradient_descent_2d()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "120ecde5-cee1-47c0-b1bc-c99b0cb1098e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.4"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
|
@ -0,0 +1,202 @@
|
||||||
|
import copy
|
||||||
|
import time
|
||||||
|
import sympy
|
||||||
|
import numpy as np
|
||||||
|
from scipy.misc import derivative
|
||||||
|
from sympy import symbols, sympify, lambdify, diff
|
||||||
|
|
||||||
|
import ipywidgets as widgets
|
||||||
|
from IPython.display import display, clear_output
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import plotly.graph_objects as go
|
||||||
|
import plotly.io as pio
|
||||||
|
pio.renderers.default = 'iframe' # or 'notebook' or 'colab' or 'jupyterlab'
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
class gradient_descent_1d(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.wg_expr = widgets.Text(value="x**3 - x**(1/2)",
|
||||||
|
description="Expression:",
|
||||||
|
style={'description_width': 'initial'})
|
||||||
|
self.wg_x0 = widgets.FloatText(value="2",
|
||||||
|
description="Startpoint:",
|
||||||
|
style={'description_width': 'initial'})
|
||||||
|
self.wg_lr = widgets.FloatText(value="1e-1",
|
||||||
|
description="learning rate:",
|
||||||
|
style={'description_width': 'initial'})
|
||||||
|
self.wg_epsilon = widgets.FloatText(value="1e-5",
|
||||||
|
description="criterion:",
|
||||||
|
style={'description_width': 'initial'})
|
||||||
|
self.wg_max_iter = widgets.IntText(value="1000",
|
||||||
|
description="max iteration",
|
||||||
|
style={'description_width': 'initial'})
|
||||||
|
|
||||||
|
self.button_compute = widgets.Button(description="Compute")
|
||||||
|
self.button_plot = widgets.Button(description="Plot")
|
||||||
|
|
||||||
|
self.compute_output = widgets.Output()
|
||||||
|
self.plot_output = widgets.Output()
|
||||||
|
self.params_lvbox = widgets.VBox([self.wg_expr, self.wg_x0, self.wg_lr])
|
||||||
|
self.params_rvbox = widgets.VBox([self.wg_epsilon, self.wg_max_iter])
|
||||||
|
self.params_box = widgets.HBox([self.params_lvbox, self.params_rvbox], description="Parameters")
|
||||||
|
self.button_box = widgets.HBox([self.button_compute, self.button_plot], description="operations")
|
||||||
|
self.config = widgets.VBox([self.params_box, self.button_box],
|
||||||
|
layout=widgets.Layout(
|
||||||
|
display='flex',
|
||||||
|
flex_flow='column',
|
||||||
|
border='solid 2px',
|
||||||
|
align_items='stretch',
|
||||||
|
width='auto'
|
||||||
|
))
|
||||||
|
self.initialization()
|
||||||
|
|
||||||
|
|
||||||
|
def initialization(self):
|
||||||
|
display(self.config)
|
||||||
|
self.button_compute.on_click(self.compute)
|
||||||
|
display(self.compute_output)
|
||||||
|
self.button_plot.on_click(self.plot)
|
||||||
|
display(self.plot_output)
|
||||||
|
|
||||||
|
def compute(self, *args):
|
||||||
|
with self.compute_output:
|
||||||
|
xn = self.wg_x0.value
|
||||||
|
x = symbols("x")
|
||||||
|
expr = sympify(self.wg_expr.value)
|
||||||
|
f = lambdify(x, expr)
|
||||||
|
df = lambdify(x, diff(expr, x))
|
||||||
|
self.xn_list, self.df_list = [], []
|
||||||
|
|
||||||
|
for n in tqdm(range(0, self.wg_max_iter.value)):
|
||||||
|
gradient = df(xn)
|
||||||
|
self.xn_list.append(xn)
|
||||||
|
self.df_list.append(gradient)
|
||||||
|
if abs (gradient < self.wg_epsilon.value):
|
||||||
|
clear_output(wait=True)
|
||||||
|
print("Found solution of {} after".format(expr), n, "iterations")
|
||||||
|
print("x* = {}".format(xn))
|
||||||
|
return None
|
||||||
|
xn = xn - self.wg_lr.value * gradient
|
||||||
|
clear_output(wait=True)
|
||||||
|
display("Exceeded maximum iterations. No solution found.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def plot(self, *args):
|
||||||
|
with self.plot_output:
|
||||||
|
clear_output(wait=True)
|
||||||
|
x0 = float(self.wg_x0.value)
|
||||||
|
x = symbols("x")
|
||||||
|
expr = sympify(self.wg_expr.value)
|
||||||
|
f = lambdify(x, sympify(expr), "numpy")
|
||||||
|
xx1 = np.arange(np.array(self.xn_list).min()*0.5, np.array(self.xn_list).max()*1.5, 0.05)
|
||||||
|
fx = f(xx1)
|
||||||
|
f_xn = f(np.array(self.xn_list))
|
||||||
|
|
||||||
|
fig = go.Figure()
|
||||||
|
fig.add_scatter(x=xx1, y=fx)
|
||||||
|
frames = []
|
||||||
|
frames.append({'data':copy.deepcopy(fig['data']),'name':f'frame{0}'})
|
||||||
|
fig.add_traces(go.Scatter(x=None, y=None, mode="lines + markers", line={"color":"#de1032", "width":5}))
|
||||||
|
frames = [go.Frame(data= [go.Scatter(x=np.array(self.xn_list)[:k], y=f_xn)],traces= [1],name=f'frame{k+2}')for k in range(len(f_xn))]
|
||||||
|
fig.update(frames=frames)
|
||||||
|
fig.update_layout(updatemenus=[dict(type="buttons",buttons=[dict(label="Play",method="animate",args=[None])])])
|
||||||
|
fig.show()
|
||||||
|
|
||||||
|
|
||||||
|
class gradient_descent_2d(object):
|
||||||
|
def __init__(self):
|
||||||
|
|
||||||
|
self.wg_expr = widgets.Text(value="(sin(x1) - 2) ** 2 + (sin(x2) - 2) ** 2",
|
||||||
|
description="Expression:") #style={'description_width': 'initial'})
|
||||||
|
self.wg_x0 = widgets.Text(value="5,5",
|
||||||
|
description="Startpoint:")
|
||||||
|
self.wg_lr = widgets.FloatText(value="1e-1",
|
||||||
|
description="learning rate:")
|
||||||
|
self.wg_epsilon = widgets.FloatText(value="1e-5",
|
||||||
|
description="criterion:")
|
||||||
|
self.wg_max_iter = widgets.IntText(value="1000",
|
||||||
|
description="max iteration")
|
||||||
|
self.button_compute = widgets.Button(description="Compute")
|
||||||
|
self.button_plot = widgets.Button(description="Plot")
|
||||||
|
|
||||||
|
self.compute_output = widgets.Output()
|
||||||
|
self.plot_output = widgets.Output()
|
||||||
|
self.params_lvbox = widgets.VBox([self.wg_expr, self.wg_x0, self.wg_lr])
|
||||||
|
self.params_rvbox = widgets.VBox([self.wg_epsilon, self.wg_max_iter])
|
||||||
|
self.params_box = widgets.HBox([self.params_lvbox, self.params_rvbox], description="Parameters")
|
||||||
|
self.button_box = widgets.HBox([self.button_compute, self.button_plot], description="operations")
|
||||||
|
self.config = widgets.VBox([self.params_box, self.button_box])
|
||||||
|
self.initialization()
|
||||||
|
|
||||||
|
def initialization(self):
|
||||||
|
display(self.config)
|
||||||
|
self.button_compute.on_click(self.compute)
|
||||||
|
display(self.compute_output)
|
||||||
|
self.button_plot.on_click(self.plot)
|
||||||
|
display(self.plot_output)
|
||||||
|
|
||||||
|
def compute(self, *args):
|
||||||
|
with self.compute_output:
|
||||||
|
x0 = np.array(self.wg_x0.value.split(","), dtype=float)
|
||||||
|
xn = x0
|
||||||
|
x1 = symbols("x1")
|
||||||
|
x2 = symbols("x2")
|
||||||
|
expr = sympify(self.wg_expr.value)
|
||||||
|
self.xn_list, self.df_list = [], []
|
||||||
|
|
||||||
|
for n in tqdm(range(0, self.wg_max_iter.value)):
|
||||||
|
gradient = np.array([diff(expr, x1).subs(x1, xn[0]).subs(x2, xn[1]),
|
||||||
|
diff(expr, x2).subs(x1, xn[0]).subs(x2, xn[1])], dtype=float)
|
||||||
|
self.xn_list.append(xn)
|
||||||
|
self.df_list.append(gradient)
|
||||||
|
if np.linalg.norm(gradient, ord=2) < self.wg_epsilon.value:
|
||||||
|
clear_output(wait=True)
|
||||||
|
print("Found solution of {} after".format(expr), n, "iterations")
|
||||||
|
print("x* = [{}, {}]".format(xn[0], xn[1]))
|
||||||
|
return None
|
||||||
|
xn = xn - self.wg_lr.value * gradient
|
||||||
|
clear_output(wait=True)
|
||||||
|
display("Exceeded maximum iterations. No solution found.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def plot(self, *args):
|
||||||
|
with self.plot_output:
|
||||||
|
clear_output(wait=True)
|
||||||
|
x0 = np.array(self.wg_x0.value.split(","), dtype=float)
|
||||||
|
x1 = symbols("x1")
|
||||||
|
x2 = symbols("x2")
|
||||||
|
expr = sympify(self.wg_expr.value)
|
||||||
|
xx1 = np.arange(np.array(self.xn_list)[:, 0].min()*0.5, np.array(self.xn_list)[:, 0].max()*1.5, 0.05)
|
||||||
|
xx2 = np.arange(np.array(self.xn_list)[:, 1].min()*0.5, np.array(self.xn_list)[:, 1].max()*1.5, 0.05)
|
||||||
|
xx1, xx2 = np.meshgrid(xx1, xx2)
|
||||||
|
|
||||||
|
f = lambdify((x1, x2), expr, "numpy")
|
||||||
|
fx = f(xx1, xx2)
|
||||||
|
f_xn = f(np.array(self.xn_list)[:, 0], np.array(self.xn_list)[:, 1])
|
||||||
|
|
||||||
|
fig = go.Figure()
|
||||||
|
fig.add_surface(x=xx1, y=xx2, z=fx, showscale=True, opacity=0.9)
|
||||||
|
fig.update_traces(contours_z=dict(show=True, usecolormap=True, highlightcolor="limegreen", project_z=True))
|
||||||
|
|
||||||
|
frames = []
|
||||||
|
frames.append({'data':copy.deepcopy(fig['data']),'name':f'frame{0}'})
|
||||||
|
|
||||||
|
line_marker=dict(color="#de1032", width=5)
|
||||||
|
fig.add_traces(go.Scatter3d(x=None, y=None, z=None, mode='lines+markers', line={"color":"#de1032", "width":5}))
|
||||||
|
|
||||||
|
frames = [go.Frame(data= [go.Scatter3d(x=np.array(self.xn_list)[:k,0], y=np.array(self.xn_list)[:k,1],z=f_xn)],traces= [1],name=f'frame{k+2}')for k in range(len(f_xn))]
|
||||||
|
fig.update(frames=frames)
|
||||||
|
fig.update_layout(updatemenus=[dict(type="buttons",
|
||||||
|
buttons=[dict(label="Play",
|
||||||
|
method="animate",
|
||||||
|
args=[None])])])
|
||||||
|
fig.show()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,202 @@
|
||||||
|
import copy
|
||||||
|
import time
|
||||||
|
import sympy
|
||||||
|
import numpy as np
|
||||||
|
from scipy.misc import derivative
|
||||||
|
from sympy import symbols, sympify, lambdify, diff
|
||||||
|
|
||||||
|
import ipywidgets as widgets
|
||||||
|
from IPython.display import display, clear_output
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import plotly.graph_objects as go
|
||||||
|
import plotly.io as pio
|
||||||
|
pio.renderers.default = 'iframe' # or 'notebook' or 'colab' or 'jupyterlab'
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
class gradient_descent_1d(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.wg_expr = widgets.Text(value="x**3 - x**(1/2)",
|
||||||
|
description="Expression:",
|
||||||
|
style={'description_width': 'initial'})
|
||||||
|
self.wg_x0 = widgets.FloatText(value="2",
|
||||||
|
description="Startpoint:",
|
||||||
|
style={'description_width': 'initial'})
|
||||||
|
self.wg_lr = widgets.FloatText(value="1e-1",
|
||||||
|
description="learning rate:",
|
||||||
|
style={'description_width': 'initial'})
|
||||||
|
self.wg_epsilon = widgets.FloatText(value="1e-5",
|
||||||
|
description="criterion:",
|
||||||
|
style={'description_width': 'initial'})
|
||||||
|
self.wg_max_iter = widgets.IntText(value="1000",
|
||||||
|
description="max iteration",
|
||||||
|
style={'description_width': 'initial'})
|
||||||
|
|
||||||
|
self.button_compute = widgets.Button(description="Compute")
|
||||||
|
self.button_plot = widgets.Button(description="Plot")
|
||||||
|
|
||||||
|
self.compute_output = widgets.Output()
|
||||||
|
self.plot_output = widgets.Output()
|
||||||
|
self.params_lvbox = widgets.VBox([self.wg_expr, self.wg_x0, self.wg_lr])
|
||||||
|
self.params_rvbox = widgets.VBox([self.wg_epsilon, self.wg_max_iter])
|
||||||
|
self.params_box = widgets.HBox([self.params_lvbox, self.params_rvbox], description="Parameters")
|
||||||
|
self.button_box = widgets.HBox([self.button_compute, self.button_plot], description="operations")
|
||||||
|
self.config = widgets.VBox([self.params_box, self.button_box],
|
||||||
|
layout=widgets.Layout(
|
||||||
|
display='flex',
|
||||||
|
flex_flow='column',
|
||||||
|
border='solid 2px',
|
||||||
|
align_items='stretch',
|
||||||
|
width='auto'
|
||||||
|
))
|
||||||
|
self.initialization()
|
||||||
|
|
||||||
|
|
||||||
|
def initialization(self):
|
||||||
|
display(self.config)
|
||||||
|
self.button_compute.on_click(self.compute)
|
||||||
|
display(self.compute_output)
|
||||||
|
self.button_plot.on_click(self.plot)
|
||||||
|
display(self.plot_output)
|
||||||
|
|
||||||
|
def compute(self, *args):
|
||||||
|
with self.compute_output:
|
||||||
|
xn = self.wg_x0.value
|
||||||
|
x = symbols("x")
|
||||||
|
expr = sympify(self.wg_expr.value)
|
||||||
|
f = lambdify(x, expr)
|
||||||
|
df = lambdify(x, diff(expr, x))
|
||||||
|
self.xn_list, self.df_list = [], []
|
||||||
|
|
||||||
|
for n in tqdm(range(0, self.wg_max_iter.value)):
|
||||||
|
gradient = df(xn)
|
||||||
|
self.xn_list.append(xn)
|
||||||
|
self.df_list.append(gradient)
|
||||||
|
if abs (gradient < self.wg_epsilon.value):
|
||||||
|
clear_output(wait=True)
|
||||||
|
print("Found solution of {} after".format(expr), n, "iterations")
|
||||||
|
print("x* = {}".format(xn))
|
||||||
|
return None
|
||||||
|
xn = xn - self.wg_lr.value * gradient
|
||||||
|
clear_output(wait=True)
|
||||||
|
display("Exceeded maximum iterations. No solution found.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def plot(self, *args):
|
||||||
|
with self.plot_output:
|
||||||
|
clear_output(wait=True)
|
||||||
|
x0 = float(self.wg_x0.value)
|
||||||
|
x = symbols("x")
|
||||||
|
expr = sympify(self.wg_expr.value)
|
||||||
|
f = lambdify(x, sympify(expr), "numpy")
|
||||||
|
xx1 = np.arange(np.array(self.xn_list).min()*0.5, np.array(self.xn_list).max()*1.5, 0.05)
|
||||||
|
fx = f(xx1)
|
||||||
|
f_xn = f(np.array(self.xn_list))
|
||||||
|
|
||||||
|
fig = go.Figure()
|
||||||
|
fig.add_scatter(x=xx1, y=fx)
|
||||||
|
frames = []
|
||||||
|
frames.append({'data':copy.deepcopy(fig['data']),'name':f'frame{0}'})
|
||||||
|
fig.add_traces(go.Scatter(x=None, y=None, mode="lines + markers", line={"color":"#de1032", "width":5}))
|
||||||
|
frames = [go.Frame(data= [go.Scatter(x=np.array(self.xn_list)[:k], y=f_xn)],traces= [1],name=f'frame{k+2}')for k in range(len(f_xn))]
|
||||||
|
fig.update(frames=frames)
|
||||||
|
fig.update_layout(updatemenus=[dict(type="buttons",buttons=[dict(label="Play",method="animate",args=[None])])])
|
||||||
|
fig.show()
|
||||||
|
|
||||||
|
|
||||||
|
class gradient_descent_2d(object):
|
||||||
|
def __init__(self):
|
||||||
|
|
||||||
|
self.wg_expr = widgets.Text(value="(sin(x1) - 2) ** 2 + (sin(x2) - 2) ** 2",
|
||||||
|
description="Expression:") #style={'description_width': 'initial'})
|
||||||
|
self.wg_x0 = widgets.Text(value="5,5",
|
||||||
|
description="Startpoint:")
|
||||||
|
self.wg_lr = widgets.FloatText(value="1e-1",
|
||||||
|
description="learning rate:")
|
||||||
|
self.wg_epsilon = widgets.FloatText(value="1e-5",
|
||||||
|
description="criterion:")
|
||||||
|
self.wg_max_iter = widgets.IntText(value="1000",
|
||||||
|
description="max iteration")
|
||||||
|
self.button_compute = widgets.Button(description="Compute")
|
||||||
|
self.button_plot = widgets.Button(description="Plot")
|
||||||
|
|
||||||
|
self.compute_output = widgets.Output()
|
||||||
|
self.plot_output = widgets.Output()
|
||||||
|
self.params_lvbox = widgets.VBox([self.wg_expr, self.wg_x0, self.wg_lr])
|
||||||
|
self.params_rvbox = widgets.VBox([self.wg_epsilon, self.wg_max_iter])
|
||||||
|
self.params_box = widgets.HBox([self.params_lvbox, self.params_rvbox], description="Parameters")
|
||||||
|
self.button_box = widgets.HBox([self.button_compute, self.button_plot], description="operations")
|
||||||
|
self.config = widgets.VBox([self.params_box, self.button_box])
|
||||||
|
self.initialization()
|
||||||
|
|
||||||
|
def initialization(self):
|
||||||
|
display(self.config)
|
||||||
|
self.button_compute.on_click(self.compute)
|
||||||
|
display(self.compute_output)
|
||||||
|
self.button_plot.on_click(self.plot)
|
||||||
|
display(self.plot_output)
|
||||||
|
|
||||||
|
def compute(self, *args):
|
||||||
|
with self.compute_output:
|
||||||
|
x0 = np.array(self.wg_x0.value.split(","), dtype=float)
|
||||||
|
xn = x0
|
||||||
|
x1 = symbols("x1")
|
||||||
|
x2 = symbols("x2")
|
||||||
|
expr = sympify(self.wg_expr.value)
|
||||||
|
self.xn_list, self.df_list = [], []
|
||||||
|
|
||||||
|
for n in tqdm(range(0, self.wg_max_iter.value)):
|
||||||
|
gradient = np.array([diff(expr, x1).subs(x1, xn[0]).subs(x2, xn[1]),
|
||||||
|
diff(expr, x2).subs(x1, xn[0]).subs(x2, xn[1])], dtype=float)
|
||||||
|
self.xn_list.append(xn)
|
||||||
|
self.df_list.append(gradient)
|
||||||
|
if np.linalg.norm(gradient, ord=2) < self.wg_epsilon.value:
|
||||||
|
clear_output(wait=True)
|
||||||
|
print("Found solution of {} after".format(expr), n, "iterations")
|
||||||
|
print("x* = [{}, {}]".format(xn[0], xn[1]))
|
||||||
|
return None
|
||||||
|
xn = xn - self.wg_lr.value * gradient
|
||||||
|
clear_output(wait=True)
|
||||||
|
display("Exceeded maximum iterations. No solution found.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def plot(self, *args):
|
||||||
|
with self.plot_output:
|
||||||
|
clear_output(wait=True)
|
||||||
|
x0 = np.array(self.wg_x0.value.split(","), dtype=float)
|
||||||
|
x1 = symbols("x1")
|
||||||
|
x2 = symbols("x2")
|
||||||
|
expr = sympify(self.wg_expr.value)
|
||||||
|
xx1 = np.arange(np.array(self.xn_list)[:, 0].min()*0.5, np.array(self.xn_list)[:, 0].max()*1.5, 0.05)
|
||||||
|
xx2 = np.arange(np.array(self.xn_list)[:, 1].min()*0.5, np.array(self.xn_list)[:, 1].max()*1.5, 0.05)
|
||||||
|
xx1, xx2 = np.meshgrid(xx1, xx2)
|
||||||
|
|
||||||
|
f = lambdify((x1, x2), expr, "numpy")
|
||||||
|
fx = f(xx1, xx2)
|
||||||
|
f_xn = f(np.array(self.xn_list)[:, 0], np.array(self.xn_list)[:, 1])
|
||||||
|
|
||||||
|
fig = go.Figure()
|
||||||
|
fig.add_surface(x=xx1, y=xx2, z=fx, showscale=True, opacity=0.9)
|
||||||
|
fig.update_traces(contours_z=dict(show=True, usecolormap=True, highlightcolor="limegreen", project_z=True))
|
||||||
|
|
||||||
|
frames = []
|
||||||
|
frames.append({'data':copy.deepcopy(fig['data']),'name':f'frame{0}'})
|
||||||
|
|
||||||
|
line_marker=dict(color="#de1032", width=5)
|
||||||
|
fig.add_traces(go.Scatter3d(x=None, y=None, z=None, mode='lines+markers', line={"color":"#de1032", "width":5}))
|
||||||
|
|
||||||
|
frames = [go.Frame(data= [go.Scatter3d(x=np.array(self.xn_list)[:k,0], y=np.array(self.xn_list)[:k,1],z=f_xn)],traces= [1],name=f'frame{k+2}')for k in range(len(f_xn))]
|
||||||
|
fig.update(frames=frames)
|
||||||
|
fig.update_layout(updatemenus=[dict(type="buttons",
|
||||||
|
buttons=[dict(label="Play",
|
||||||
|
method="animate",
|
||||||
|
args=[None])])])
|
||||||
|
fig.show()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue