From 0bccf62d542659c75ac209c68d65d77288bd3431 Mon Sep 17 00:00:00 2001 From: TerenceLiu98 Date: Sat, 30 Jul 2022 17:24:01 +0000 Subject: [PATCH] update --- .gitignore | 1 + .../interactive-checkpoint.ipynb | 75 +++++++ .../__pycache__/interactive.cpython-310.pyc | Bin 0 -> 1931 bytes algorithm/interactive.ipynb | 149 +++++++++++++ .../.ipynb_checkpoints/gd-checkpoint.py | 202 ++++++++++++++++++ .../__pycache__/gd.cpython-310.pyc | Bin 0 -> 7312 bytes .../__pycache__/gd.cpython-39.pyc | Bin 0 -> 6257 bytes .../__pycache__/interactive.cpython-310.pyc | Bin 0 -> 5719 bytes algorithm/optimization/gd.py | 202 ++++++++++++++++++ 9 files changed, 629 insertions(+) create mode 100644 .gitignore create mode 100644 algorithm/.ipynb_checkpoints/interactive-checkpoint.ipynb create mode 100644 algorithm/__pycache__/interactive.cpython-310.pyc create mode 100644 algorithm/interactive.ipynb create mode 100644 algorithm/optimization/.ipynb_checkpoints/gd-checkpoint.py create mode 100644 algorithm/optimization/__pycache__/gd.cpython-310.pyc create mode 100644 algorithm/optimization/__pycache__/gd.cpython-39.pyc create mode 100644 algorithm/optimization/__pycache__/interactive.cpython-310.pyc create mode 100644 algorithm/optimization/gd.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..65e3ba2 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +test/ diff --git a/algorithm/.ipynb_checkpoints/interactive-checkpoint.ipynb b/algorithm/.ipynb_checkpoints/interactive-checkpoint.ipynb new file mode 100644 index 0000000..9167237 --- /dev/null +++ b/algorithm/.ipynb_checkpoints/interactive-checkpoint.ipynb @@ -0,0 +1,75 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "af1d5646-4b13-4039-9f76-8042bc9dbda3", + "metadata": {}, + "outputs": [], + "source": [ + "from optimization.gd import gradient_descent_1d, gradient_descent_2d" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0a5c51d2-8b18-4143-b6f1-73a909ccb623", + "metadata": {}, + "outputs": [], + "source": [ + "gd2 = gradient_descent_2d()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "f254549d-d9b0-43ac-a502-e65b7c462d05", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "plotly.graph_objs._scatter3d.Scatter3d" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import plotly.graph_objects as go\n", + "go.Scatter3d" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c0f6766-4f41-419b-94a0-916a379ebcd1", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/algorithm/__pycache__/interactive.cpython-310.pyc b/algorithm/__pycache__/interactive.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e98882063529bda2e993eb1f9202f8b781e168b1 GIT binary patch literal 1931 zcmaJ?OOG2x5bmD0AM1T&14}?+6d{OLT6-5n4l6*BAVIlUNP){JXf*3?duN$fy4#CA zj8D#$6MrB%=3nqT`kE7e0i=kk@vdW&6*KDU>YA?VdVJNBZ*4URJh%UOefEb($ZK4z zKNc(=z>rUYFv4g;g5^vD1J*`jrdD84l$i-h?Z8f*z)9V}O})UYo)3bNJT63_6THAzk+`IDrhhR5uwH@nxjqOWw7#D`)9hp2(Uh;?YcdNtB*3 z=z~TW6`3VqC0?G)uO_}kW1ST_?1+%quCa{?cZg~3ESVS!5 znF<+~L!3|8kb?g8gS-vnTNqLRQRIT6Q8XA+X8&$hBs4&3GKaYvQmdAFtg#`rYiX0U zHlz+nUDjq>8&Z!EaI-TgJJZ9%FZ?4Prh6y-&&$@sX(70jah{z9yXd=!3k9q2*O;ku z?Z#QG;wb5w+P;jE3D>Q)P1=@fmT*nM&=VC2RpfD|PD|^AADxt)ghwKav#~EC#h21A z%LbgFpjOdoV4WNvAD6pnG>y|q>Z2G6m}lLFcH#IjS5jL)@~ML5IOO1I(0G*Ok*Z}j z$fm~{r$p!`?kHrOMa27%{=WzP^E~B!_`5k4)xwR z=Ksy@H|0k+fTwqD?i~~%JZlu)2S$+!aI~Ow$r&roTyfTlvyrQjKnGn%xS%g~knAGa zLvjnrK9bu=FjGr%0G#ezq2VeSAA(F`xJ|lzJG8ueQ!JZm8xf)UsxSV81;hgp&8dQx zThO_opf4^gjB#!%yD}}WF^B%p`ctDMRm z83v)yPZb5t&IPR+`{aPk>9fz6gLl$E)2YCIFYi3cCmHi)o=gzeem?SloBPoSuoW~{ zaF?Y!J@mhZPi6n%bjUg5%zvjTWX~VuZ%vKRa<99+tQ}@im`L#MT^1FyDjXo%7Ew0l zby>9&6$Q_jHls6HS6kbK0JLjol#4V{H4M5<@g67!#xxUHWNl&dX*$vd8wKt(3zJx? zz-6PF(KOTEb>07qnbt5*i;1ec=PuZiJ`j(3v<;xyfw5IM@X*rc}~|1eYM rdDgr7>D0}qo&4;A4^@5GO}=UGt7Tl@AeOQfOFC^CIRt literal 0 HcmV?d00001 diff --git a/algorithm/interactive.ipynb b/algorithm/interactive.ipynb new file mode 100644 index 0000000..f5764e3 --- /dev/null +++ b/algorithm/interactive.ipynb @@ -0,0 +1,149 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "af1d5646-4b13-4039-9f76-8042bc9dbda3", + "metadata": {}, + "outputs": [], + "source": [ + "from optimization.gd import gradient_descent_1d, gradient_descent_2d" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "0a5c51d2-8b18-4143-b6f1-73a909ccb623", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "923a9d77eecb4d78ade9e3113cf636c0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(HBox(children=(VBox(children=(Text(value='x**3 - x**(1/2)', description='Expression:', style=De…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c2252623a38940f0a80471ec28e491d3", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "778d4e34c4ae4b0d8f2f1b4e0d3c1173", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "gd1 = gradient_descent_1d()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "15c6e757-cde3-422b-be7e-3f55b7752142", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "298df33c058b49e6927e582d532bb87b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(HBox(children=(VBox(children=(Text(value='(sin(x1) - 2) ** 2 + (sin(x2) - 2) ** 2', description…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "47315c498b9a46fd9948fd7a53d3e80a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0cab52875285496a989f685de24515cb", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "gd2 = gradient_descent_2d()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "120ecde5-cee1-47c0-b1bc-c99b0cb1098e", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/algorithm/optimization/.ipynb_checkpoints/gd-checkpoint.py b/algorithm/optimization/.ipynb_checkpoints/gd-checkpoint.py new file mode 100644 index 0000000..c794080 --- /dev/null +++ b/algorithm/optimization/.ipynb_checkpoints/gd-checkpoint.py @@ -0,0 +1,202 @@ +import copy +import time +import sympy +import numpy as np +from scipy.misc import derivative +from sympy import symbols, sympify, lambdify, diff + +import ipywidgets as widgets +from IPython.display import display, clear_output + +from tqdm import tqdm + +import plotly.graph_objects as go +import plotly.io as pio +pio.renderers.default = 'iframe' # or 'notebook' or 'colab' or 'jupyterlab' + +import warnings +warnings.filterwarnings("ignore") + +class gradient_descent_1d(object): + def __init__(self): + self.wg_expr = widgets.Text(value="x**3 - x**(1/2)", + description="Expression:", + style={'description_width': 'initial'}) + self.wg_x0 = widgets.FloatText(value="2", + description="Startpoint:", + style={'description_width': 'initial'}) + self.wg_lr = widgets.FloatText(value="1e-1", + description="learning rate:", + style={'description_width': 'initial'}) + self.wg_epsilon = widgets.FloatText(value="1e-5", + description="criterion:", + style={'description_width': 'initial'}) + self.wg_max_iter = widgets.IntText(value="1000", + description="max iteration", + style={'description_width': 'initial'}) + + self.button_compute = widgets.Button(description="Compute") + self.button_plot = widgets.Button(description="Plot") + + self.compute_output = widgets.Output() + self.plot_output = widgets.Output() + self.params_lvbox = widgets.VBox([self.wg_expr, self.wg_x0, self.wg_lr]) + self.params_rvbox = widgets.VBox([self.wg_epsilon, self.wg_max_iter]) + self.params_box = widgets.HBox([self.params_lvbox, self.params_rvbox], description="Parameters") + self.button_box = widgets.HBox([self.button_compute, self.button_plot], description="operations") + self.config = widgets.VBox([self.params_box, self.button_box], + layout=widgets.Layout( + display='flex', + flex_flow='column', + border='solid 2px', + align_items='stretch', + width='auto' + )) + self.initialization() + + + def initialization(self): + display(self.config) + self.button_compute.on_click(self.compute) + display(self.compute_output) + self.button_plot.on_click(self.plot) + display(self.plot_output) + + def compute(self, *args): + with self.compute_output: + xn = self.wg_x0.value + x = symbols("x") + expr = sympify(self.wg_expr.value) + f = lambdify(x, expr) + df = lambdify(x, diff(expr, x)) + self.xn_list, self.df_list = [], [] + + for n in tqdm(range(0, self.wg_max_iter.value)): + gradient = df(xn) + self.xn_list.append(xn) + self.df_list.append(gradient) + if abs (gradient < self.wg_epsilon.value): + clear_output(wait=True) + print("Found solution of {} after".format(expr), n, "iterations") + print("x* = {}".format(xn)) + return None + xn = xn - self.wg_lr.value * gradient + clear_output(wait=True) + display("Exceeded maximum iterations. No solution found.") + return None + + def plot(self, *args): + with self.plot_output: + clear_output(wait=True) + x0 = float(self.wg_x0.value) + x = symbols("x") + expr = sympify(self.wg_expr.value) + f = lambdify(x, sympify(expr), "numpy") + xx1 = np.arange(np.array(self.xn_list).min()*0.5, np.array(self.xn_list).max()*1.5, 0.05) + fx = f(xx1) + f_xn = f(np.array(self.xn_list)) + + fig = go.Figure() + fig.add_scatter(x=xx1, y=fx) + frames = [] + frames.append({'data':copy.deepcopy(fig['data']),'name':f'frame{0}'}) + fig.add_traces(go.Scatter(x=None, y=None, mode="lines + markers", line={"color":"#de1032", "width":5})) + frames = [go.Frame(data= [go.Scatter(x=np.array(self.xn_list)[:k], y=f_xn)],traces= [1],name=f'frame{k+2}')for k in range(len(f_xn))] + fig.update(frames=frames) + fig.update_layout(updatemenus=[dict(type="buttons",buttons=[dict(label="Play",method="animate",args=[None])])]) + fig.show() + + +class gradient_descent_2d(object): + def __init__(self): + + self.wg_expr = widgets.Text(value="(sin(x1) - 2) ** 2 + (sin(x2) - 2) ** 2", + description="Expression:") #style={'description_width': 'initial'}) + self.wg_x0 = widgets.Text(value="5,5", + description="Startpoint:") + self.wg_lr = widgets.FloatText(value="1e-1", + description="learning rate:") + self.wg_epsilon = widgets.FloatText(value="1e-5", + description="criterion:") + self.wg_max_iter = widgets.IntText(value="1000", + description="max iteration") + self.button_compute = widgets.Button(description="Compute") + self.button_plot = widgets.Button(description="Plot") + + self.compute_output = widgets.Output() + self.plot_output = widgets.Output() + self.params_lvbox = widgets.VBox([self.wg_expr, self.wg_x0, self.wg_lr]) + self.params_rvbox = widgets.VBox([self.wg_epsilon, self.wg_max_iter]) + self.params_box = widgets.HBox([self.params_lvbox, self.params_rvbox], description="Parameters") + self.button_box = widgets.HBox([self.button_compute, self.button_plot], description="operations") + self.config = widgets.VBox([self.params_box, self.button_box]) + self.initialization() + + def initialization(self): + display(self.config) + self.button_compute.on_click(self.compute) + display(self.compute_output) + self.button_plot.on_click(self.plot) + display(self.plot_output) + + def compute(self, *args): + with self.compute_output: + x0 = np.array(self.wg_x0.value.split(","), dtype=float) + xn = x0 + x1 = symbols("x1") + x2 = symbols("x2") + expr = sympify(self.wg_expr.value) + self.xn_list, self.df_list = [], [] + + for n in tqdm(range(0, self.wg_max_iter.value)): + gradient = np.array([diff(expr, x1).subs(x1, xn[0]).subs(x2, xn[1]), + diff(expr, x2).subs(x1, xn[0]).subs(x2, xn[1])], dtype=float) + self.xn_list.append(xn) + self.df_list.append(gradient) + if np.linalg.norm(gradient, ord=2) < self.wg_epsilon.value: + clear_output(wait=True) + print("Found solution of {} after".format(expr), n, "iterations") + print("x* = [{}, {}]".format(xn[0], xn[1])) + return None + xn = xn - self.wg_lr.value * gradient + clear_output(wait=True) + display("Exceeded maximum iterations. No solution found.") + return None + + def plot(self, *args): + with self.plot_output: + clear_output(wait=True) + x0 = np.array(self.wg_x0.value.split(","), dtype=float) + x1 = symbols("x1") + x2 = symbols("x2") + expr = sympify(self.wg_expr.value) + xx1 = np.arange(np.array(self.xn_list)[:, 0].min()*0.5, np.array(self.xn_list)[:, 0].max()*1.5, 0.05) + xx2 = np.arange(np.array(self.xn_list)[:, 1].min()*0.5, np.array(self.xn_list)[:, 1].max()*1.5, 0.05) + xx1, xx2 = np.meshgrid(xx1, xx2) + + f = lambdify((x1, x2), expr, "numpy") + fx = f(xx1, xx2) + f_xn = f(np.array(self.xn_list)[:, 0], np.array(self.xn_list)[:, 1]) + + fig = go.Figure() + fig.add_surface(x=xx1, y=xx2, z=fx, showscale=True, opacity=0.9) + fig.update_traces(contours_z=dict(show=True, usecolormap=True, highlightcolor="limegreen", project_z=True)) + + frames = [] + frames.append({'data':copy.deepcopy(fig['data']),'name':f'frame{0}'}) + + line_marker=dict(color="#de1032", width=5) + fig.add_traces(go.Scatter3d(x=None, y=None, z=None, mode='lines+markers', line={"color":"#de1032", "width":5})) + + frames = [go.Frame(data= [go.Scatter3d(x=np.array(self.xn_list)[:k,0], y=np.array(self.xn_list)[:k,1],z=f_xn)],traces= [1],name=f'frame{k+2}')for k in range(len(f_xn))] + fig.update(frames=frames) + fig.update_layout(updatemenus=[dict(type="buttons", + buttons=[dict(label="Play", + method="animate", + args=[None])])]) + fig.show() + + + + + \ No newline at end of file diff --git a/algorithm/optimization/__pycache__/gd.cpython-310.pyc b/algorithm/optimization/__pycache__/gd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33198ff3138ce14d5d58bde9196aa3b68721ebd0 GIT binary patch literal 7312 zcmb7JYjY!4TJC$dTCyzL@^voJ=0c*$#J3PGA)6U;nIuK<4kYX@y_*$notE61mRfVV zH6v?eD9Q=R0^Vd{`LH*Tp`glDeBuYdAAmp5pC~A5E5#O2l`oW`c6i>?ZCjqvB!yh( z?Q`C9>F#sh%k#9usVQ5+@80&mZ7)0}N&iZX@t=mqb(HAuA(+JEKyuQn?8u_6IEuKc zj)toe=xxI>WE!IeO51eIw&hsuoRe$YjxF?Zkar5`*Mg~b(J8j4ooS)hgHpTXl(I21 z?OA6otC!pJ&I0O&x6oRwEs6UrJImhkXoZ=FvMjwJF^lEimsrlbu&+87(Y9G0?YwAT zV)8puweTt4LE5UyiOoFj-*IFAj#t$ZGumx$he0GNUB9uLSb^K#W>nK48&xGSnICln zcQ+~21JC8PuoriGF_zNe+pL`!euKMh57%ZVe+z*T1^vv4(74i9Z7b4~6t-|NIR=0$b7Zm_zHbxwX%&!N1E60M>YOD&le zr-gl)=?CgiswvQ_%)sK~T1{xNcjvUa&|(MAX$@#iR$x=-v=+2ER%Fu$GVWrFRw--c zV`)@4l%StFtrbz59!f0*MMKe>>LJZ6WqJcTtEmXx3|5;(yM!{+RM-sG#k|^FrnjNb zLqD78%h1!>b13B|z41JoV`Vna7T6+NI?&m&n7=TTMvI3MSr@i=Wot0qzjEdI%5xQ5 zp4nL2tRClI?{~QuMSj?M(V3x*=YBUvt9I9C@lIm;9Y6NnpsFVN9XIHCNq%flqDS#= z;3e|rasHjy<#9LkJMoLh+J^Vs#&MA>sN;8<74F7fs(s{9=N-v>1}B|pH-j)v?3*rG9!5op9d@(9k)t&Nub&w8FzB^AgIp8_ zKC5hY`;HmK+>7fwiRSj=u&Q$ti#a(OUTXy5-Ne`qIsC)Py8*1Zh6iXzi7s9n_HMul zVEomk#DuEp#ZjWY>-A$an>7!cmgH^(p&JXG4qbm8lbf}GCpL}fMt(pq^HwJvl7|xS zNw?vN@k%d_(OcXW*IGS&<0PN8yY$A!SA_2+#jH2;#UxL?Ck^eZuY~=i(520ZYQdfD zu@)CC?M)hCXCtWJ&hA2W<8ASpw1J(5-%N_xdHVx#7)NmY$O{@r62FG+JO0hJov`h# z;T^nA-3$EQ8uRXWLD+4>U)ON!C zxIM*}&t8Uz;n6L5A8{8~WrXO`(kKcGy`Jd}p%>9v=q=Kdza7b2M>%12{8z9TUxRQg zcvd~|>$eg!J666)o$C;ZMkmXkr}mdg(3Yp-I@Ml+coF3@N@X_QmP%0`JU;0S&!6MD z_oNtVp}c_-Eg_ELy`+zSPrRkvQudTsWq9j7H5PHJrJ%0G#!ww<`^uh9wb&eLLv>$4 z+h|HnNug+Eq2yva&NGU61I?im(g{+_IkhWdCX%K_`l!owa;Q*v|8B;&z5S?%B$7+)Fi67 zEiaKEqvV88_|qhG5(Y$K_B*w}j}VktLsU4OZK89x)AaarR5#pi*XyuEb+;q_EX|~& zPxLN_wfV}v3<`P0FLlO@N}|7zwV=ge1FKUSy5 zjlPb<`ACNO{sjy}fB|7N3Uq2qgUKne&QxK{26U#-X|V+pHek*&)Am)EbZ%s$s0R|vf(lMip1PC5< z8kqlIKV<$Fw%+A$VV6@I&^x_$clRe~(TC1X@*1tjb+6 zy&WyUJ|^fXUe6zk{_tr2xR|RHOSGsp3@}BSx#6Ci)G5j?H)K zZcOn&>2wp_<=oxnTQpNil-2b_zICJkx76TVYA0m;x6+5x z;#~l%d(@Dfe6Nc`^pMzlQT0-qmD11w^n{S{`QeFU5yYM}5JhN+4M^0s-+^%ZJQM?N z_{|>o0KpimMd_yT-=H3?9uhQQp4SyMn?__Cew&uNNu$)j>m)|{pvhEvt);+?7)?C} zdf5r@R_%17=`|74)>DJ}8ZGm6u}r_ektmIRq5=Z(0qI5Ry-T(0ByK>^(CFtNECjHE z2y7*J71xq%DFi(#@@4rVdY(l8n}U*L37oS-gWAhAediNrFA6%rRnTqJRc#AOnXK~xv{ zQ&cA-Ord~q!e^){GPua^lzBz&=c`mB*eT#8-=La6n*w6~JT+e+@e3q=k;E^N_+=6T zvJ&E)2t~;>$%7+W&QKKDy8kFygl~+4^fEkT0w~{PUku9mEA%{Mh@4{l#3!Ri{*XG~ zfS3S{ai1@M#sr~gazu#>1}1&{w}dSOkq4$2=}bP9V}mINfXYJ!Fd3g3kc%~3332E^ z7x+{lkj5D}28FNX<3c>ebo3U{n}XwD&!n~EX&?sZMkNB)Ob1G!x{kU*b)UXvFiJJjo+KH@V^xK6Ip@TfMzLw(c@3(yWP~`|ln?9n)LK@>OIpol!bMyhd#9 zZweH_13K9@L{&~y(Dw|V{2Hm>h8Ro<)cuWndshLy-{ZedLswFxrP{+xu!Net{Fts-dvJV~F(A6w^Pc1wj`Yw06{!QjXN z$SGorfgQs!EHE~5j2uSUEz;?+GRmXSJV0#%ydNmcYygmr2zJAK9??q>JO>DF69k{Q zO8#?KNqfgvF|i7{oq4Ey6Fe4FR@&}OrOD$Frcz|wL` z6);bH<-pN#K3dRu=@!Wqv zwGSZ@tL;TQP42Vws|jHi{t-k*IeNSS4(FUuhm2r+Ort2a^KX;*LlS>P;;%>$^Wc9> z;(!DNWMO*MsWX<52KO_I7jgYhFw#-`$Yn$&_@9#a9tmO`PM(YmpNNSNPof`xgF(@Y z5H>>lQWp4c$TvVgO85r42@*nIK{cf!7f>!D30UHLs3sdfE8MrWC;w#vCCScei~kut z$XUKQJMZtKgG}a(T`YlJEOF9TFN+<*KZoFkj*kB?b$2;Au{b!#j_Q7QceU-~zY0hp z<0+OHh4_Z$*cg(1VHfB&nSSf$E?8-2HTzdo`U%yd36zPQG8tE}y9#>M-KmA!u&mgiYO^I`~bxLv~8rz3o0ev mc+wUn)=Mdadzk`TRDlpIOr&O?C736GHAkavJ=Ok$GIN$B)W_00CTdwN{m zW6w@z1(L(ja&8Ee!>q(*FPu1V=7t1_bBhBKNRcHj+|VjY`2Ol%d)8jZAR%T{e^>og z{a1C>|M>oLD-}z^@3;T_`-a_+q`y&P`lq4s3`+cK5GFA>knHR!+X~uBpoW^Q$;49w zDb#H}G;AX**oDxvO~IFgqHUpH3rbnb{^Dfv6-oj*&6?SAda9EGm7!y_K4Q8*VVFX`t<wP>R0pMu6SDDfE(PqJkuVY_6ed??u}Q<;XR#&l-jsj~tz z@ibTwi?-7F4tJUF^%949af_a-tR+Lv{*UT9%5y03akLVtE0b^%?8{8Qu8t)~0as-P zgimud!NuMk;Oc^l9X!A_z%8&cs~q5(;1*ex&0Lo;izQmMyj4o1NqI*CzjCWqMQvs* zbrlp1MenFP^kyyR8{nB8Met@JZ4T`k%4|nrvyhATocWw@fnNfDF6Y;C8S~)Ra~Vg* z(&XrlM4E){oL?W!3{Rgv*Erk2HTA-%Fm{(^kE+t9w<_mY(11G`O31 z8TW}j3o|58K}?uxt1BxjdzH`~Hi!qMjCyI|xhRBwdAl;N;#5=ZLJ%dXb-^Vq!>Bm5 zqJBO&ZZ4z+thD1LajISPh6$P-$HPaZ#TSFfO$1K|Z@7}uBjBk;Bl@u)&}6UlvLPif zu?~mUks8kqk_5fg4e@kZ*^1Iq-tN0V<)(> zG1cOtrM*letb7FZTlp+hH(F7z?RU~@eh&UfY|<`_FZP1=uEd|f9__upycvbwGS=?( zT3+A}mYH|O3!;7qn_0%0V82>q@@FHQZkS71c7slY)7T7`53oCIsSiisI5er_%#&?B z1tQ6YtjNY^rh1>JR+X!YC0p{MTt^`~I(djGFYQlpmrbjfP1y9eQAuD$UHKY(9iGYr z7Twjz^9jD5^9{il9#Qa3;*;L*%ImvDUVt3FO5`gbwh0Yy1%B&NTF8%&H>vY6kW`}+ zZLcy0&2vj`a`%1b1Rd>6h;`bfI*srXL0 zt86QY%JALWY9hQvS3zA%jIlb_t|{9()sn(k8>`n8w2h9`kreVcCQ31}k`g0NQ^tFE z^LgpLpKce%#+amFMAw{D#_%;_wj%1L*wgBm`ewSyCFz>X|D4pw<6WoSo@M%EZj7BD z7npHP6`TS%kU@JbOK)MnHTRsd*$*iUy|T6(XdToNxetkO)N{Zh)kP zVb2Nt819F)MTOJpraE_f9gjasb;IrVy&g+dcO&Kx)0=eosov+X&(vr~JaiNOCF&q$ zZ5H`gsb6z>Cr*`NkCO?en(%FM8zHov%DvRwSGqg3^hu2V21B|FLhh}+f58kmYU#%BpcEqf3WcHLSGyY2O-Xd73Pq+6 zR}efcF`>Q&R9R-)H5KYyoLDF&j41EOlgf_7^tv>~se)tFCBPna+ZgL(0}>0#%%rv> zL7~m=ER-Ag;kp8qM(g&QW0IZA#cV}pU5AQ}&B+`E$7vPp)mEkWPj{(bY5gL98XuXR zf!-U0{jE=EUnS|U=Q~;DU*|vh!?mk_#|s_5XadGe?u@LYqPsZ=m;{`j#raK6FVaR;s(4yIcHO5~qfK55yg&N+(K<7yZtFdv?iX z%!ylWLRO*0o6*%=vG^1aJZT(?P01S@Z70Y-@=I3P~t@p2^wQT zr%11A@{(*SCUoghd9-+#7Lik3-z`vF=^>IkPHH(0wqOuYz3e!b2X2t{@UKI-@UxuM znA5CND-r%!xZ)Bz=1^kVJ}CHuX2D$G#BHG*`2BBT=-L;+&`&WhKw1t%1!yAx3M~i9 zRYhAxTgw2cXcOqv2*&6A8Qc^gRP<*+6Yz5iXjb#q{eUKg%l`&w&Ok0K_b#AW%lY>Q znzQSp$4R$39z}=Y@9ybSZh4q$e3u~V)8OJ3+5ucE5W^yzPfT|b%obZ=d(oU zd^uqdUMF&d$WbB-L`V<$F(Su_oB(Mq@RL-3fQazZ4XQmvM7USsd>^4EIcEN4B7_P= z^ylO+Grunoz~j`U)P*lfq<^;BOv+4+Z0AbmW1#z(V4s>Ck9il0}+iC zpdsXOaQzyd6kK!!669ebsNBXt$d8;-QcfyNM{gCq83sxbG9WvdL9hUCQX~At^kdQv z)pgVjs*8U8-u)dVUlGl1O=ic)$^mJB3&XP@V!d4bs~3QeF{H6L5a!Q2}Lc+ z8sZ&AJVH^A%8kzz*t4OO!8tC_)Z!PRDcHkdXo_}*vI$`b08v>gAm8}aK|IwJsZUp* z7!2x0>Q5$-m<&k}%)$gLC$lJ$N_L5-Ckk+n!V1Tv60V8rfaFxC8Y>ct%KNRUe$;RE z7tmiORJEhLiv$CvQ=#Oet&DY60%im1qFy8)#j%O8w7%*NZg`mVGtAUhj!AcME#l%0 zD@SgnaVbM?<039@D8}h9q*8Y-nMYg3^@@xu4vp0Wrkuw`$Q+x;cb74yewXwwm~(8g zdiM6m3Oka|a)cda3v3ZrMGNd0Xn`GPC)ml4Ra~N3>;X}G@MDcVBv553yH1%ltKOV2 zs^8pYRE>Si*ulZ_O(M6oj;x9YM1DZz5)r~xoj?A8e*ZY%)TfZ5JhF=f-GJi)&n`rL zx8)~Wb}{fnufsjB*VOoXc-t-wVy_hiluGpN>Zae>3{VomvWtBlb-h;NjGD4-;bI|) z20V5~qrxN1TU|M~c9Zg%ref=2Jg!UV3W&c!%&FDz2n>D^B-IGq+U3{}`v7rQH<5JN zMlV5H9p9v2eN!XyIez>CYSZXfNaX+ad@X_pxko zJ%sB>Gx0DTZj^BQgeyT0;r=v)i)@M?;&FJ7dJpqAZ=u6BLoeR!aGzz_obW;K&>Qbk zZw8+SyzRC;yE5oA%=uWx zzK9TZoz@s^Eg{|LZ#vNiX?JW_4m3F(CUid==|`#JN4DArDdMsWm*I4ySYTe;9Rx{g zUd^s7^^NP1444?7g|^B~4<-Zk(iOH!k8PEt8bxYA7vbQj~|3 ztyQ!r3MY?7P@pd@{IL6Ipili3Ezmv}AW)z{`eGzN`@9I^erL#CYwdMX33E7eW;ipP zIp=)eOkAzn5`Mq??B5&DJSj<^QepC^p>P2yxdXu@CWn%fKV?TjSqarhb2OQBYA8jz zqeq5gM5bd#mSYLM9F`m#^;%etDo!P;I#r?9!&+2xYDJr=Xxf=6^0U#LQ%BzL>)m;G z;dP0b%=$=TmVfM)>MWvMVm8XQC?99?+fuW94>Oe3n{s9|p9eR*G`Qh6waiSmqK!CA zM5Z6Kw=ygAq76nl4YJ)-GLr>KKlHY;N-Oj|?#6?(KSu`?+elK;IpS0+?%33ns-~Ulw zN4kWRU}2?H>dMq9br{G@zpaiXSAkY#2G(U#s|l@z)tS`lLTj_~A*}(e$ttXRNNYh` zVl_5(TgE7MD(#eaCFCl|RmW0SLDG=)j=D=NYK7i_&gv*aHwE1^$~B~^LN|kcXHlL( zvO5Z!7X8nmG`CY{Gi-KOkui=a&9Qo3+L_;#$Reu)01>7^fP&Ful#oE z@w6WYz4XOBZPh=!x>qAR=mot_gL|o;YhT#2F+vLC!ho5!y0WseSB<=3gLE*rxR;ri z;s}=G-7@!-J=?(^j5#EUuIk!q$z4^ z8{+A<@|kAFTf*40TA?V6Khxg17!R{b--GEVZg^v3qQpf>dxKioMGMqBrB_?6m@`_5 zlJ_8+_^F!)DHgj{Y)LQ@>vb0%miS@&uEf8CHQ0M^c{7guWlY!awfry`EHnRxAIAL% zjsB^g8Lu z;_u4qO@o_g&zDI&2a#E@?N%7Hu4QJiTf9k?Yb0m|bMXS@Rw2$HB?g2o<~@4)s97K0 zl0$PZ;Z+Oib?RpkW{SB=pZrG5T<$8{N~$t^$+nt`n9x;_*HUAwjg@}WQB1S7AZ;4f^jj3*` zt6Yo>#no!K$ga%?i=mMSzRH0Xm?RhHhx%4+yi9j)=0MxR4Wv+B~K zw5?oUp6IR4EPAW-$XlyY^6^d4pDg`yKR%V#GkGXn{3;psO)p6IPF#)$J=RF#a6rMK z5w{y3Za2I(d82(0Gm?=xJl%L1ReQ^?4O_m?eAa;12GJmT5LuF?#(Mmqt2Xt#)SS<^ z|C`tzUL!$<#tF0eQzQsSIPHIC4ts7GBnTy}Ei#<;Hq*J+>-hXzlsCM7-|w+Z^)?dz z482KPpXq%LN6L(L%p))5Pg4b%WwXT3QN8BzPLe6Z9w(Q|G!dudcBZg)CigOH-{c;) z&mW*Q*=RyWYRUAk$hJHUOI=VXUqRlMXHiz<>Hp5o*l3~1&wf67V!z7?2wf`t0Zv1X z-UD2_Zf3h~6th7{`HJgaA9!J2!@r9`g^O^S$Hd%(7YPR{qk@8XLd#&$6i-%kPTCgI zq2B=!pl|tlB0yZmF<1$>W*c_?`8Q-SG1KS>4lVV~&0 z3}Al>qtKk^Fejn|Hd2}BK%uV}eGoCI7t~>XEbS~D5P_NX(Q~Jgpm%Dx+5{O`Yc@`w zZmcz)Z{$jp50t!u#cC?KXV6ZL^wpI!t1EmKc|J#?PGX+K0*PZJ7D*f@aRS28u7q)# z*;fb$eYkGo*zrD6C(UCVtXbmUAR$~%_+o>K!c&AF5j_+>#wq^ip}btca38(%Q>3Oy zo|UOwKrlZ|>gP$2UGcLdmLMFvc$qNSXbCQn8}m^>K&Jj!A+&YzPQ64gwF3Lz zMlsXGGG3*cqn3c;%mz|IhaJMg$KoMA?Q8Im;-M7w@sZ03>hzLh;%V{O8G99%w)aHm7ZkyD1It5&W4!jy@ z%Jk#XF6DLP4a$pp{qgl3rI-8vI2b%B+4j*Sh0uN8r>3#C*Y$dOjPzB1A8aaMaQ(*F4oq@7yd zlg|KSb(U`ReE`oY(wcx?eiD0gkAg9Xw+bS9mH!Y0{xS)&$5B;4?Ry_?p8?)}z|T|D z6S?Opy?&;9oO@gR08ep7=!Zd?>20Fj52Nb^>h@Ih+y~RWUr{QkBWztHZ2iydY=Nzx9f7Uzhyq&=QNt=mAa4?`@Rxm8_y@TH z9ae?GC;}g}uD)Y17#tiyBd16$2AT`Uut3v@FqX(DyQI@oWv4t=n0Z_(cfn)4Aq{&XOC)mk5 zDzou_QslmIM`KUjrh|1kN6?1}FsST5A91k8df_(jVte^7NIbAVY_!DRhv@wA5A?fx zUN~K*_XbgatMm82(eLx~O+9CW@>V8~oKmtG-%MIw=sRZI_gX=^<&?r8@;lu3d$^;~ zsDV=+Bz`LniD&km+Gfz%43Scya!P$3cl}oCj+(M#1M|{&z!P^gGM{1o>dLvbdqj3| zm!pgJI2Y3$z^Dq>^ONTzk6#0H7t`WCeMK80)4&h?Orus#B?+QFxZ=%CkaNfAr8sCM z_lP3gBT9YmK7Ji!@cY;KrYgShL083s{PlcObMztjd*bLL@vvK94%~D7E#E17NgrP! z4*r;|mHHSGHw=zWmq9-Gga9$3ocvcLToMNXiT|349toAi0o4b47gRp8B0t&e@POsy z{{mWNN+-^Y%XHsB6JX3u2E6UHe5X3-GYsxxJ}n=Y4J>G-PF1vMc`3-wId;$n*JHRA z>afK zEmHClgoPNjSVXV}qSv%M4RW$3TR4j^;(V}%dgH#OXozBlVxVNp+84%`w)OOx8MYv z(d)E%j-9lE{?<|y;A)CsX=2mJS_M!zHkuUIJ%UVTQ*U0`N;l))QgJJp6Chy{CF-?K zUlDFC;RE}dZoENmkvP>uMV?U0QqeMqGX+WQ2eDJa-3l&9;SD$o`fYCz;&7LDA`Um= kV~+az%z7z5Oua%uAR(GeOF&R!Dbq4Q1+Zdai|Q}_57)S5`Tzg` literal 0 HcmV?d00001 diff --git a/algorithm/optimization/gd.py b/algorithm/optimization/gd.py new file mode 100644 index 0000000..c794080 --- /dev/null +++ b/algorithm/optimization/gd.py @@ -0,0 +1,202 @@ +import copy +import time +import sympy +import numpy as np +from scipy.misc import derivative +from sympy import symbols, sympify, lambdify, diff + +import ipywidgets as widgets +from IPython.display import display, clear_output + +from tqdm import tqdm + +import plotly.graph_objects as go +import plotly.io as pio +pio.renderers.default = 'iframe' # or 'notebook' or 'colab' or 'jupyterlab' + +import warnings +warnings.filterwarnings("ignore") + +class gradient_descent_1d(object): + def __init__(self): + self.wg_expr = widgets.Text(value="x**3 - x**(1/2)", + description="Expression:", + style={'description_width': 'initial'}) + self.wg_x0 = widgets.FloatText(value="2", + description="Startpoint:", + style={'description_width': 'initial'}) + self.wg_lr = widgets.FloatText(value="1e-1", + description="learning rate:", + style={'description_width': 'initial'}) + self.wg_epsilon = widgets.FloatText(value="1e-5", + description="criterion:", + style={'description_width': 'initial'}) + self.wg_max_iter = widgets.IntText(value="1000", + description="max iteration", + style={'description_width': 'initial'}) + + self.button_compute = widgets.Button(description="Compute") + self.button_plot = widgets.Button(description="Plot") + + self.compute_output = widgets.Output() + self.plot_output = widgets.Output() + self.params_lvbox = widgets.VBox([self.wg_expr, self.wg_x0, self.wg_lr]) + self.params_rvbox = widgets.VBox([self.wg_epsilon, self.wg_max_iter]) + self.params_box = widgets.HBox([self.params_lvbox, self.params_rvbox], description="Parameters") + self.button_box = widgets.HBox([self.button_compute, self.button_plot], description="operations") + self.config = widgets.VBox([self.params_box, self.button_box], + layout=widgets.Layout( + display='flex', + flex_flow='column', + border='solid 2px', + align_items='stretch', + width='auto' + )) + self.initialization() + + + def initialization(self): + display(self.config) + self.button_compute.on_click(self.compute) + display(self.compute_output) + self.button_plot.on_click(self.plot) + display(self.plot_output) + + def compute(self, *args): + with self.compute_output: + xn = self.wg_x0.value + x = symbols("x") + expr = sympify(self.wg_expr.value) + f = lambdify(x, expr) + df = lambdify(x, diff(expr, x)) + self.xn_list, self.df_list = [], [] + + for n in tqdm(range(0, self.wg_max_iter.value)): + gradient = df(xn) + self.xn_list.append(xn) + self.df_list.append(gradient) + if abs (gradient < self.wg_epsilon.value): + clear_output(wait=True) + print("Found solution of {} after".format(expr), n, "iterations") + print("x* = {}".format(xn)) + return None + xn = xn - self.wg_lr.value * gradient + clear_output(wait=True) + display("Exceeded maximum iterations. No solution found.") + return None + + def plot(self, *args): + with self.plot_output: + clear_output(wait=True) + x0 = float(self.wg_x0.value) + x = symbols("x") + expr = sympify(self.wg_expr.value) + f = lambdify(x, sympify(expr), "numpy") + xx1 = np.arange(np.array(self.xn_list).min()*0.5, np.array(self.xn_list).max()*1.5, 0.05) + fx = f(xx1) + f_xn = f(np.array(self.xn_list)) + + fig = go.Figure() + fig.add_scatter(x=xx1, y=fx) + frames = [] + frames.append({'data':copy.deepcopy(fig['data']),'name':f'frame{0}'}) + fig.add_traces(go.Scatter(x=None, y=None, mode="lines + markers", line={"color":"#de1032", "width":5})) + frames = [go.Frame(data= [go.Scatter(x=np.array(self.xn_list)[:k], y=f_xn)],traces= [1],name=f'frame{k+2}')for k in range(len(f_xn))] + fig.update(frames=frames) + fig.update_layout(updatemenus=[dict(type="buttons",buttons=[dict(label="Play",method="animate",args=[None])])]) + fig.show() + + +class gradient_descent_2d(object): + def __init__(self): + + self.wg_expr = widgets.Text(value="(sin(x1) - 2) ** 2 + (sin(x2) - 2) ** 2", + description="Expression:") #style={'description_width': 'initial'}) + self.wg_x0 = widgets.Text(value="5,5", + description="Startpoint:") + self.wg_lr = widgets.FloatText(value="1e-1", + description="learning rate:") + self.wg_epsilon = widgets.FloatText(value="1e-5", + description="criterion:") + self.wg_max_iter = widgets.IntText(value="1000", + description="max iteration") + self.button_compute = widgets.Button(description="Compute") + self.button_plot = widgets.Button(description="Plot") + + self.compute_output = widgets.Output() + self.plot_output = widgets.Output() + self.params_lvbox = widgets.VBox([self.wg_expr, self.wg_x0, self.wg_lr]) + self.params_rvbox = widgets.VBox([self.wg_epsilon, self.wg_max_iter]) + self.params_box = widgets.HBox([self.params_lvbox, self.params_rvbox], description="Parameters") + self.button_box = widgets.HBox([self.button_compute, self.button_plot], description="operations") + self.config = widgets.VBox([self.params_box, self.button_box]) + self.initialization() + + def initialization(self): + display(self.config) + self.button_compute.on_click(self.compute) + display(self.compute_output) + self.button_plot.on_click(self.plot) + display(self.plot_output) + + def compute(self, *args): + with self.compute_output: + x0 = np.array(self.wg_x0.value.split(","), dtype=float) + xn = x0 + x1 = symbols("x1") + x2 = symbols("x2") + expr = sympify(self.wg_expr.value) + self.xn_list, self.df_list = [], [] + + for n in tqdm(range(0, self.wg_max_iter.value)): + gradient = np.array([diff(expr, x1).subs(x1, xn[0]).subs(x2, xn[1]), + diff(expr, x2).subs(x1, xn[0]).subs(x2, xn[1])], dtype=float) + self.xn_list.append(xn) + self.df_list.append(gradient) + if np.linalg.norm(gradient, ord=2) < self.wg_epsilon.value: + clear_output(wait=True) + print("Found solution of {} after".format(expr), n, "iterations") + print("x* = [{}, {}]".format(xn[0], xn[1])) + return None + xn = xn - self.wg_lr.value * gradient + clear_output(wait=True) + display("Exceeded maximum iterations. No solution found.") + return None + + def plot(self, *args): + with self.plot_output: + clear_output(wait=True) + x0 = np.array(self.wg_x0.value.split(","), dtype=float) + x1 = symbols("x1") + x2 = symbols("x2") + expr = sympify(self.wg_expr.value) + xx1 = np.arange(np.array(self.xn_list)[:, 0].min()*0.5, np.array(self.xn_list)[:, 0].max()*1.5, 0.05) + xx2 = np.arange(np.array(self.xn_list)[:, 1].min()*0.5, np.array(self.xn_list)[:, 1].max()*1.5, 0.05) + xx1, xx2 = np.meshgrid(xx1, xx2) + + f = lambdify((x1, x2), expr, "numpy") + fx = f(xx1, xx2) + f_xn = f(np.array(self.xn_list)[:, 0], np.array(self.xn_list)[:, 1]) + + fig = go.Figure() + fig.add_surface(x=xx1, y=xx2, z=fx, showscale=True, opacity=0.9) + fig.update_traces(contours_z=dict(show=True, usecolormap=True, highlightcolor="limegreen", project_z=True)) + + frames = [] + frames.append({'data':copy.deepcopy(fig['data']),'name':f'frame{0}'}) + + line_marker=dict(color="#de1032", width=5) + fig.add_traces(go.Scatter3d(x=None, y=None, z=None, mode='lines+markers', line={"color":"#de1032", "width":5})) + + frames = [go.Frame(data= [go.Scatter3d(x=np.array(self.xn_list)[:k,0], y=np.array(self.xn_list)[:k,1],z=f_xn)],traces= [1],name=f'frame{k+2}')for k in range(len(f_xn))] + fig.update(frames=frames) + fig.update_layout(updatemenus=[dict(type="buttons", + buttons=[dict(label="Play", + method="animate", + args=[None])])]) + fig.show() + + + + + \ No newline at end of file