绘制L-System的分形图

最后更新于:2022-04-01 11:16:22

# 绘制L-System的分形图 相关文档: [_L-System分形_](fractal_chaos.html#sec-lsystem) ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bbbe10d4.png) ``` # -*- coding: utf-8 -*- #L-System(Lindenmayer system)是一种用字符串替代产生分形图形的算法 from math import sin, cos, pi import matplotlib.pyplot as pl from matplotlib import collections class L_System(object): def __init__(self, rule): info = rule['S'] for i in range(rule['iter']): ninfo = [] for c in info: if c in rule: ninfo.append(rule[c]) else: ninfo.append(c) info = "".join(ninfo) self.rule = rule self.info = info def get_lines(self): d = self.rule['direct'] a = self.rule['angle'] p = (0.0, 0.0) l = 1.0 lines = [] stack = [] for c in self.info: if c in "Ff": r = d * pi / 180 t = p[0] + l*cos(r), p[1] + l*sin(r) lines.append(((p[0], p[1]), (t[0], t[1]))) p = t elif c == "+": d += a elif c == "-": d -= a elif c == "[": stack.append((p,d)) elif c == "]": p, d = stack[-1] del stack[-1] return lines rules = [ { "F":"F+F--F+F", "S":"F", "direct":180, "angle":60, "iter":5, "title":"Koch" }, { "X":"X+YF+", "Y":"-FX-Y", "S":"FX", "direct":0, "angle":90, "iter":13, "title":"Dragon" }, { "f":"F-f-F", "F":"f+F+f", "S":"f", "direct":0, "angle":60, "iter":7, "title":"Triangle" }, { "X":"F-[[X]+X]+F[+FX]-X", "F":"FF", "S":"X", "direct":-45, "angle":25, "iter":6, "title":"Plant" }, { "S":"X", "X":"-YF+XFX+FY-", "Y":"+XF-YFY-FX+", "direct":0, "angle":90, "iter":6, "title":"Hilbert" }, { "S":"L--F--L--F", "L":"+R-F-R+", "R":"-L+F+L-", "direct":0, "angle":45, "iter":10, "title":"Sierpinski" }, ] def draw(ax, rule, iter=None): if iter!=None: rule["iter"] = iter lines = L_System(rule).get_lines() linecollections = collections.LineCollection(lines) ax.add_collection(linecollections, autolim=True) ax.axis("equal") ax.set_axis_off() ax.set_xlim(ax.dataLim.xmin, ax.dataLim.xmax) ax.invert_yaxis() fig = pl.figure(figsize=(7,4.5)) fig.patch.set_facecolor("w") for i in xrange(6): ax = fig.add_subplot(231+i) draw(ax, rules[i]) fig.subplots_adjust(left=0,right=1,bottom=0,top=1,wspace=0,hspace=0) pl.show() ```
';

迭代函数系统的分形

最后更新于:2022-04-01 11:16:20

# 迭代函数系统的分形 相关文档: [_迭代函数系统(IFS)_](fractal_chaos.html#sec-ifs) ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bbb31a12.png) ``` # -*- coding: utf-8 -*- import numpy as np import matplotlib.pyplot as pl import time # 蕨类植物叶子的迭代函数和其概率值 eq1 = np.array([[0,0,0],[0,0.16,0]]) p1 = 0.01 eq2 = np.array([[0.2,-0.26,0],[0.23,0.22,1.6]]) p2 = 0.07 eq3 = np.array([[-0.15, 0.28, 0],[0.26,0.24,0.44]]) p3 = 0.07 eq4 = np.array([[0.85, 0.04, 0],[-0.04, 0.85, 1.6]]) p4 = 0.85 def ifs(p, eq, init, n): """ 进行函数迭代 p: 每个函数的选择概率列表 eq: 迭代函数列表 init: 迭代初始点 n: 迭代次数 返回值: 每次迭代所得的X坐标数组, Y坐标数组, 计算所用的函数下标 """ # 迭代向量的初始化 pos = np.ones(3, dtype=np.float) pos[:2] = init # 通过函数概率,计算函数的选择序列 p = np.add.accumulate(p) rands = np.random.rand(n) select = np.ones(n, dtype=np.int)*(n-1) for i, x in enumerate(p[::-1]): select[rands<x] = len(p)-i-1 # 结果的初始化 result = np.zeros((n,2), dtype=np.float) c = np.zeros(n, dtype=np.float) for i in xrange(n): eqidx = select[i] # 所选的函数下标 tmp = np.dot(eq[eqidx], pos) # 进行迭代 pos[:2] = tmp # 更新迭代向量 # 保存结果 result[i] = tmp c[i] = eqidx return result[:,0], result[:, 1], c start = time.clock() x, y, c = ifs([p1,p2,p3,p4],[eq1,eq2,eq3,eq4], [0,0], 100000) print time.clock() - start pl.figure(figsize=(6,6)) pl.subplot(121) pl.scatter(x, y, s=1, c="g", marker="s", linewidths=0) pl.axis("equal") pl.axis("off") pl.subplot(122) pl.scatter(x, y, s=1,c = c, marker="s", linewidths=0) pl.axis("equal") pl.axis("off") pl.subplots_adjust(left=0,right=1,bottom=0,top=1,wspace=0,hspace=0) pl.gcf().patch.set_facecolor("white") pl.show() ``` ## 迭代函数系统设计器 <object classid="clsid:D27CDB6E-AE6D-11cf-96B8-444553540000" width="600" height="370" codebase="http://active.macromedia.com/flash5/cabs/swflash.cab#version=7,0,0,0"><param name="movie" value="img/ifs.swf"> <param name="play" value="true"> <param name="loop" value="false"> <param name="wmode" value="transparent"> <param name="quality" value="high"> <embed src="img/ifs.swf" width="600" height="370" quality="high" loop="false" wmode="transparent" type="application/x-shockwave-flash" pluginspage="http://www.macromedia.com/shockwave/download/index.cgi?P1_Prod_Version=ShockwaveFlash"> </object> ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bbb83791.swf) ``` # -*- coding: utf-8 -*- from enthought.traits.ui.api import * from enthought.traits.ui.menu import OKCancelButtons from enthought.traits.api import * from enthought.traits.ui.wx.editor import Editor import matplotlib # matplotlib采用WXAgg为后台,这样才能将绘图控件嵌入以wx为后台界面库的traitsUI窗口中 matplotlib.use("WXAgg") from matplotlib.backends.backend_wxagg import FigureCanvasWxAgg as FigureCanvas from matplotlib.figure import Figure import numpy as np import thread import time import wx import pickle ITER_COUNT = 4000 # 一次ifs迭代的点数 ITER_TIMES = 10 # 总共调用ifs的次数 def triangle_area(triangle): """ 计算三角形的面积 """ A = triangle[0] B = triangle[1] C = triangle[2] AB = A-B AC = A-C return np.abs(np.cross(AB,AC))/2.0 def solve_eq(triangle1, triangle2): """ 解方程,从triangle1变换到triangle2的变换系数 triangle1,2是二维数组: x0,y0 x1,y1 x2,y2 """ x0,y0 = triangle1[0] x1,y1 = triangle1[1] x2,y2 = triangle1[2] a = np.zeros((6,6), dtype=np.float) b = triangle2.reshape(-1) a[0, 0:3] = x0,y0,1 a[1, 3:6] = x0,y0,1 a[2, 0:3] = x1,y1,1 a[3, 3:6] = x1,y1,1 a[4, 0:3] = x2,y2,1 a[5, 3:6] = x2,y2,1 c = np.linalg.solve(a, b) c.shape = (2,3) return c def ifs(p, eq, init, n): """ 进行函数迭代 p: 每个函数的选择概率列表 eq: 迭代函数列表 init: 迭代初始点 n: 迭代次数 返回值: 每次迭代所得的X坐标数组, Y坐标数组, 计算所用的函数下标 """ # 迭代向量的初始化 pos = np.ones(3, dtype=np.float) pos[:2] = init # 通过函数概率,计算函数的选择序列 p = np.add.accumulate(p) rands = np.random.rand(n) select = np.ones(n, dtype=np.int)*(n-1) for i, x in enumerate(p[::-1]): select[rands<x] = len(p)-i-1 # 结果的初始化 result = np.zeros((n,2), dtype=np.float) c = np.zeros(n, dtype=np.float) for i in xrange(n): eqidx = select[i] # 所选的函数下标 tmp = np.dot(eq[eqidx], pos) # 进行迭代 pos[:2] = tmp # 更新迭代向量 # 保存结果 result[i] = tmp c[i] = eqidx return result[:,0], result[:, 1], c class _MPLFigureEditor(Editor): """ 使用matplotlib figure的traits编辑器 """ scrollable = True def init(self, parent): self.control = self._create_canvas(parent) def update_editor(self): pass def _create_canvas(self, parent): panel = wx.Panel(parent, -1, style=wx.CLIP_CHILDREN) sizer = wx.BoxSizer(wx.VERTICAL) panel.SetSizer(sizer) mpl_control = FigureCanvas(panel, -1, self.value) sizer.Add(mpl_control, 1, wx.LEFT | wx.TOP | wx.GROW) self.value.canvas.SetMinSize((10,10)) return panel class MPLFigureEditor(BasicEditorFactory): """ 相当于traits.ui中的EditorFactory,它返回真正创建控件的类 """ klass = _MPLFigureEditor class IFSTriangles(HasTraits): """ 三角形编辑器 """ version = Int(0) # 三角形更新标志 def __init__(self, ax): super(IFSTriangles, self).__init__() self.colors = ["r","g","b","c","m","y","k"] self.points = np.array([(0,0),(2,0),(2,4),(0,1),(1,1),(1,3),(1,1),(2,1),(2,3)], dtype=np.float) self.equations = self.get_eqs() self.ax = ax self.ax.set_ylim(-10,10) self.ax.set_xlim(-10,10) canvas = ax.figure.canvas # 绑定canvas的鼠标事件 canvas.mpl_connect('button_press_event', self.button_press_callback) canvas.mpl_connect('button_release_event', self.button_release_callback) canvas.mpl_connect('motion_notify_event', self.motion_notify_callback) self.canvas = canvas self._ind = None self.background = None self.update_lines() def refresh(self): """ 重新绘制所有的三角形 """ self.update_lines() self.canvas.draw() self.version += 1 def del_triangle(self): """ 删除最后一个三角形 """ self.points = self.points[:-3].copy() self.refresh() def add_triangle(self): """ 添加一个三角形 """ self.points = np.vstack((self.points, np.array([(0,0),(1,0),(0,1)],dtype=np.float))) self.refresh() def set_points(self, points): """ 直接设置三角形定点 """ self.points = points.copy() self.refresh() def get_eqs(self): """ 计算所有的仿射方程 """ eqs = [] for i in range(1,len(self.points)/3): eqs.append( solve_eq( self.points[:3,:], self.points[i*3:i*3+3,:]) ) return eqs def get_areas(self): """ 通过三角形的面积计算仿射方程的迭代概率 """ areas = [] for i in range(1, len(self.points)/3): areas.append( triangle_area(self.points[i*3:i*3+3,:]) ) s = sum(areas) return [x/s for x in areas] def update_lines(self): """ 重新绘制所有的三角形 """ del self.ax.lines[:] for i in xrange(0,len(self.points),3): color = self.colors[i/3%len(self.colors)] x0, x1, x2 = self.points[i:i+3, 0] y0, y1, y2 = self.points[i:i+3, 1] type = color+"%so" if i==0: linewidth = 3 else: linewidth = 1 self.ax.plot([x0,x1],[y0,y1], type % "-", linewidth=linewidth) self.ax.plot([x1,x2],[y1,y2], type % "--", linewidth=linewidth) self.ax.plot([x0,x2],[y0,y2], type % ":", linewidth=linewidth) self.ax.set_ylim(-10,10) self.ax.set_xlim(-10,10) def button_release_callback(self, event): """ 鼠标按键松开事件 """ self._ind = None def button_press_callback(self, event): """ 鼠标按键按下事件 """ if event.inaxes!=self.ax: return if event.button != 1: return self._ind = self.get_ind_under_point(event.xdata, event.ydata) def get_ind_under_point(self, mx, my): """ 找到距离mx, my最近的顶点 """ for i, p in enumerate(self.points): if abs(mx-p[0]) < 0.5 and abs(my-p[1])< 0.5: return i return None def motion_notify_callback(self, event): """ 鼠标移动事件 """ self.event = event if self._ind is None: return if event.inaxes != self.ax: return if event.button != 1: return x,y = event.xdata, event.ydata #更新定点坐标 self.points[self._ind,:] = [x, y] i = self._ind / 3 * 3 # 更新顶点对应的三角形线段 x0, x1, x2 = self.points[i:i+3, 0] y0, y1, y2 = self.points[i:i+3, 1] self.ax.lines[i].set_data([x0,x1],[y0,y1]) self.ax.lines[i+1].set_data([x1,x2],[y1,y2]) self.ax.lines[i+2].set_data([x0,x2],[y0,y2]) # 背景为空时,捕捉背景 if self.background == None: self.ax.clear() self.ax.set_axis_off() self.canvas.draw() self.background = self.canvas.copy_from_bbox(self.ax.bbox) self.update_lines() # 快速绘制所有三角形 self.canvas.restore_region(self.background) #恢复背景 # 绘制所有三角形 for line in self.ax.lines: self.ax.draw_artist(line) self.canvas.blit(self.ax.bbox) self.version += 1 class AskName(HasTraits): name = Str("") view = View( Item("name", label = u"名称"), kind = "modal", buttons = OKCancelButtons ) class IFSHandler(Handler): """ 在界面显示之前需要初始化的内容 """ def init(self, info): info.object.init_gui_component() return True class IFSDesigner(HasTraits): figure = Instance(Figure) # 控制绘图控件的Figure对象 ifs_triangle = Instance(IFSTriangles) add_button = Button(u"添加三角形") del_button = Button(u"删除三角形") save_button = Button(u"保存当前IFS") unsave_button = Button(u"删除当前IFS") clear = Bool(True) exit = Bool(False) ifs_names = List() ifs_points = List() current_name = Str view = View( VGroup( HGroup( Item("add_button"), Item("del_button"), Item("current_name", editor = EnumEditor(name="object.ifs_names")), Item("save_button"), Item("unsave_button"), show_labels = False ), Item("figure", editor=MPLFigureEditor(), show_label=False, width=600), ), resizable = True, height = 350, width = 600, title = u"迭代函数系统设计器", handler = IFSHandler() ) def _current_name_changed(self): self.ifs_triangle.set_points( self.ifs_points[ self.ifs_names.index(self.current_name) ] ) def _add_button_fired(self): """ 添加三角形按钮事件处理 """ self.ifs_triangle.add_triangle() def _del_button_fired(self): self.ifs_triangle.del_triangle() def _unsave_button_fired(self): if self.current_name in self.ifs_names: index = self.ifs_names.index(self.current_name) del self.ifs_names[index] del self.ifs_points[index] self.save_data() def _save_button_fired(self): """ 保存按钮处理 """ ask = AskName(name = self.current_name) if ask.configure_traits(): if ask.name not in self.ifs_names: self.ifs_names.append( ask.name ) self.ifs_points.append( self.ifs_triangle.points.copy() ) else: index = self.ifs_names.index(ask.name) self.ifs_names[index] = ask.name self.ifs_points[index] = self.ifs_triangle.points.copy() self.save_data() def save_data(self): with file("IFS.data", "wb") as f: pickle.dump(self.ifs_names[:], f) # ifs_names不是list,因此需要先转换为list for data in self.ifs_points: np.save(f, data) # 保存多个数组 def ifs_calculate(self): """ 在别的线程中计算 """ def draw_points(x, y, c): if len(self.ax2.collections) < ITER_TIMES: try: self.ax2.scatter(x, y, s=1, c=c, marker="s", linewidths=0) self.ax2.set_axis_off() self.ax2.axis("equal") self.figure.canvas.draw() except: pass def clear_points(): self.ax2.clear() while 1: try: if self.exit == True: break if self.clear == True: self.clear = False self.initpos = [0, 0] # 不绘制迭代的初始100个点 x, y, c = ifs( self.ifs_triangle.get_areas(), self.ifs_triangle.get_eqs(), self.initpos, 100) self.initpos = [x[-1], y[-1]] self.ax2.clear() x, y, c = ifs( self.ifs_triangle.get_areas(), self.ifs_triangle.get_eqs(), self.initpos, ITER_COUNT) if np.max(np.abs(x)) < 1000000 and np.max(np.abs(y)) < 1000000: self.initpos = [x[-1], y[-1]] wx.CallAfter( draw_points, x, y, c ) time.sleep(0.05) except: pass @on_trait_change("ifs_triangle.version") def on_ifs_version_changed(self): """ 当三角形更新时,重新绘制所有的迭代点 """ self.clear = True def _figure_default(self): """ figure属性的缺省值,直接创建一个Figure对象 """ figure = Figure() self.ax = figure.add_subplot(121) self.ax2 = figure.add_subplot(122) self.ax2.set_axis_off() self.ax.set_axis_off() figure.subplots_adjust(left=0,right=1,bottom=0,top=1,wspace=0,hspace=0) figure.patch.set_facecolor("w") return figure def init_gui_component(self): self.ifs_triangle = IFSTriangles(self.ax) self.figure.canvas.draw() thread.start_new_thread( self.ifs_calculate, ()) try: with file("ifs.data","rb") as f: self.ifs_names = pickle.load(f) self.ifs_points = [] for i in xrange(len(self.ifs_names)): self.ifs_points.append(np.load(f)) if len(self.ifs_names) > 0: self.current_name = self.ifs_names[-1] except: pass designer = IFSDesigner() designer.configure_traits() designer.exit = True ```
';

绘制Mandelbrot集合

最后更新于:2022-04-01 11:16:17

# 绘制Mandelbrot集合 相关文档: [_Mandelbrot集合_](fractal_chaos.html#sec-mandelbrot) ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bbaae9d4.png) ## 纯Python计算版本 ``` # -*- coding: utf-8 -*- import numpy as np import pylab as pl import time from matplotlib import cm def iter_point(c): z = c for i in xrange(1, 100): # 最多迭代100次 if abs(z)>2: break # 半径大于2则认为逃逸 z = z*z+c return i # 返回迭代次数 def draw_mandelbrot(cx, cy, d): """ 绘制点(cx, cy)附近正负d的范围的Mandelbrot """ x0, x1, y0, y1 = cx-d, cx+d, cy-d, cy+d y, x = np.ogrid[y0:y1:200j, x0:x1:200j] c = x + y*1j start = time.clock() mandelbrot = np.frompyfunc(iter_point,1,1)(c).astype(np.float) print "time=",time.clock() - start pl.imshow(mandelbrot, cmap=cm.Blues_r, extent=[x0,x1,y0,y1]) pl.gca().set_axis_off() x,y = 0.27322626, 0.595153338 pl.subplot(231) draw_mandelbrot(-0.5,0,1.5) for i in range(2,7): pl.subplot(230+i) draw_mandelbrot(x, y, 0.2**(i-1)) pl.subplots_adjust(0.02, 0, 0.98, 1, 0.02, 0) pl.show() ``` ## Weave版本 ``` # -*- coding: utf-8 -*- import numpy as np import pylab as pl import time import scipy.weave as weave from matplotlib import cm def weave_iter_point(c): code = """ std::complex<double> z; int i; z = c; for(i=1;i<100;i++) { if(std::abs(z) > 2) break; z = z*z+c; } return_val=i; """ f = weave.inline(code, ["c"], compiler="gcc") return f def draw_mandelbrot(cx, cy, d,N=200): """ 绘制点(cx, cy)附近正负d的范围的Mandelbrot """ x0, x1, y0, y1 = cx-d, cx+d, cy-d, cy+d y, x = np.ogrid[y0:y1:N*1j, x0:x1:N*1j] c = x + y*1j start = time.clock() mandelbrot = np.frompyfunc(weave_iter_point,1,1)(c).astype(np.float) print "time=",time.clock() - start pl.imshow(mandelbrot, cmap=cm.Blues_r, extent=[x0,x1,y0,y1]) pl.gca().set_axis_off() x,y = 0.27322626, 0.595153338 pl.subplot(231) draw_mandelbrot(-0.5,0,1.5) for i in range(2,7): pl.subplot(230+i) draw_mandelbrot(x, y, 0.2**(i-1)) pl.subplots_adjust(0.02, 0, 0.98, 1, 0.02, 0.02) pl.show() ``` ## NumPy加速版本 ``` # -*- coding: utf-8 -*- import numpy as np import pylab as pl import time from matplotlib import cm def draw_mandelbrot(cx, cy, d, N=200): """ 绘制点(cx, cy)附近正负d的范围的Mandelbrot """ global mandelbrot x0, x1, y0, y1 = cx-d, cx+d, cy-d, cy+d y, x = np.ogrid[y0:y1:N*1j, x0:x1:N*1j] c = x + y*1j # 创建X,Y轴的坐标数组 ix, iy = np.mgrid[0:N,0:N] # 创建保存mandelbrot图的二维数组,缺省值为最大迭代次数 mandelbrot = np.ones(c.shape, dtype=np.int)*100 # 将数组都变成一维的 ix.shape = -1 iy.shape = -1 c.shape = -1 z = c.copy() # 从c开始迭代,因此开始的迭代次数为1 start = time.clock() for i in xrange(1,100): # 进行一次迭代 z *= z z += c # 找到所有结果逃逸了的点 tmp = np.abs(z) > 2.0 # 将这些逃逸点的迭代次数赋值给mandelbrot图 mandelbrot[ix[tmp], iy[tmp]] = i # 找到所有没有逃逸的点 np.logical_not(tmp, tmp) # 更新ix, iy, c, z只包含没有逃逸的点 ix,iy,c,z = ix[tmp], iy[tmp], c[tmp],z[tmp] if len(z) == 0: break print "time=",time.clock() - start pl.imshow(mandelbrot, cmap=cm.Blues_r, extent=[x0,x1,y0,y1]) pl.gca().set_axis_off() x,y = 0.27322626, 0.595153338 pl.subplot(231) draw_mandelbrot(-0.5,0,1.5) for i in range(2,7): pl.subplot(230+i) draw_mandelbrot(x, y, 0.2**(i-1)) pl.subplots_adjust(0.02, 0, 0.98, 1, 0.02, 0) pl.show() ``` ## 平滑版本 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bbb05dd5.png) ``` # -*- coding: utf-8 -*- import numpy as np import pylab as pl from math import log from matplotlib import cm escape_radius = 10 iter_num = 20 def smooth_iter_point(c): z = c for i in xrange(1, iter_num): if abs(z)>escape_radius: break z = z*z+c absz = abs(z) if absz > 2.0: mu = i - log(log(abs(z),2),2) else: mu = i return mu # 返回正规化的迭代次数 def iter_point(c): z = c for i in xrange(1, iter_num): if abs(z)>escape_radius: break z = z*z+c return i def draw_mandelbrot(cx, cy, d, N=200): global mandelbrot """ 绘制点(cx, cy)附近正负d的范围的Mandelbrot """ x0, x1, y0, y1 = cx-d, cx+d, cy-d, cy+d y, x = np.ogrid[y0:y1:N*1j, x0:x1:N*1j] c = x + y*1j mand = np.frompyfunc(iter_point,1,1)(c).astype(np.float) smooth_mand = np.frompyfunc(smooth_iter_point,1,1)(c).astype(np.float) pl.subplot(121) pl.gca().set_axis_off() pl.imshow(mand, cmap=cm.Blues_r, extent=[x0,x1,y0,y1]) pl.subplot(122) pl.imshow(smooth_mand, cmap=cm.Blues_r, extent=[x0,x1,y0,y1]) pl.gca().set_axis_off() draw_mandelbrot(-0.5,0,1.5,300) pl.subplots_adjust(0.02, 0, 0.98, 1, 0.02, 0) pl.show() ```
';

双摆系统的动画模拟

最后更新于:2022-04-01 11:16:15

# 双摆系统的动画模拟 相关文档: [_单摆和双摆模拟_](double_pendulum.html) ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bba2d608.png) ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bba3f9f8.png) ## 用odeint解双摆系统 文件名: double_pendulum_odeint.py ``` # -*- coding: utf-8 -*- from math import sin,cos import numpy as np from scipy.integrate import odeint g = 9.8 class DoublePendulum(object): def __init__(self, m1, m2, l1, l2): self.m1, self.m2, self.l1, self.l2 = m1, m2, l1, l2 self.init_status = np.array([0.0,0.0,0.0,0.0]) def equations(self, w, t): """ 微分方程公式 """ m1, m2, l1, l2 = self.m1, self.m2, self.l1, self.l2 th1, th2, v1, v2 = w dth1 = v1 dth2 = v2 #eq of th1 a = l1*l1*(m1+m2) # dv1 parameter b = l1*m2*l2*cos(th1-th2) # dv2 paramter c = l1*(m2*l2*sin(th1-th2)*dth2*dth2 + (m1+m2)*g*sin(th1)) #eq of th2 d = m2*l2*l1*cos(th1-th2) # dv1 parameter e = m2*l2*l2 # dv2 parameter f = m2*l2*(-l1*sin(th1-th2)*dth1*dth1 + g*sin(th2)) dv1, dv2 = np.linalg.solve([[a,b],[d,e]], [-c,-f]) return np.array([dth1, dth2, dv1, dv2]) def double_pendulum_odeint(pendulum, ts, te, tstep): """ 对双摆系统的微分方程组进行数值求解,返回两个小球的X-Y坐标 """ t = np.arange(ts, te, tstep) track = odeint(pendulum.equations, pendulum.init_status, t) th1_array, th2_array = track[:,0], track[:, 1] l1, l2 = pendulum.l1, pendulum.l2 x1 = l1*np.sin(th1_array) y1 = -l1*np.cos(th1_array) x2 = x1 + l2*np.sin(th2_array) y2 = y1 - l2*np.cos(th2_array) pendulum.init_status = track[-1,:].copy() #将最后的状态赋给pendulum return [x1, y1, x2, y2] if __name__ == "__main__": import matplotlib.pyplot as pl pendulum = DoublePendulum(1.0, 2.0, 1.0, 2.0) th1, th2 = 1.0, 2.0 pendulum.init_status[:2] = th1, th2 x1, y1, x2, y2 = double_pendulum_odeint(pendulum, 0, 30, 0.02) pl.plot(x1,y1, label = u"上球") pl.plot(x2,y2, label = u"下球") pl.title(u"双摆系统的轨迹, 初始角度=%s,%s" % (th1, th2)) pl.legend() pl.axis("equal") pl.show() ``` ## 摆动动画 文件名: double_pendulum_animation.py ``` # -*- coding: utf-8 -*- import matplotlib matplotlib.use('WXAgg') # do this before importing pylab import matplotlib.pyplot as pl from double_pendulum_odeint import double_pendulum_odeint, DoublePendulum fig = pl.figure(figsize=(4,4)) line1, = pl.plot([0,0], [0,0], "-o") line2, = pl.plot([0,0], [0,0], "-o") pl.axis("equal") pl.xlim(-4,4) pl.ylim(-4,2) pendulum = DoublePendulum(1.0, 2.0, 1.0, 2.0) pendulum.init_status[:] = 1.0, 2.0, 0, 0 x1, y1, x2, y2 = [],[],[],[] idx = 0 def update_line(event): global x1, x2, y1, y2, idx if idx == len(x1): x1, y1, x2, y2 = double_pendulum_odeint(pendulum, 0, 1, 0.05) idx = 0 line1.set_xdata([0, x1[idx]]) line1.set_ydata([0, y1[idx]]) line2.set_xdata([x1[idx], x2[idx]]) line2.set_ydata([y1[idx], y2[idx]]) fig.canvas.draw() idx += 1 import wx id = wx.NewId() actor = fig.canvas.manager.frame timer = wx.Timer(actor, id=id) timer.Start(1) wx.EVT_TIMER(actor, id, update_line) pl.show() ```
';

单摆摆动周期的计算

最后更新于:2022-04-01 11:16:13

# 单摆摆动周期的计算 相关文档: [_单摆和双摆模拟_](double_pendulum.html) 本程序利用odeint和fsolve计算单摆的摆动周期,并且和精确值进行比较。 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb89b52f.png) ``` # -*- coding: utf-8 -*- from math import sin, sqrt import numpy as np from scipy.integrate import odeint from scipy.optimize import fsolve import pylab as pl from scipy.special import ellipk g = 9.8 def pendulum_equations(w, t, l): th, v = w dth = v dv = - g/l * sin(th) return dth, dv def pendulum_th(t, l, th0): track = odeint(pendulum_equations, (th0, 0), [0, t], args=(l,)) return track[-1, 0] def pendulum_period(l, th0): t0 = 2*np.pi*sqrt( l/g ) / 4 t = fsolve( pendulum_th, t0, args = (l, th0) ) return t*4 ths = np.arange(0, np.pi/2.0, 0.01) periods = [pendulum_period(1, th) for th in ths] periods2 = 4*sqrt(1.0/g)*ellipk(np.sin(ths/2)**2) # 计算单摆周期的精确值 pl.plot(ths, periods, label = u"fsolve计算的单摆周期", linewidth=4.0) pl.plot(ths, periods2, "r", label = u"单摆周期精确值", linewidth=2.0) pl.legend(loc='upper left') pl.title(u"长度为1米单摆:初始摆角-摆动周期") pl.xlabel(u"初始摆角(弧度)") pl.ylabel(u"摆动周期(秒)") pl.show() ```
';

二次均衡器设计

最后更新于:2022-04-01 11:16:11

# 二次均衡器设计 相关文档: [_数字信号系统_](filters.html) 用Traits.UI和Chaco制作的二次均衡器的设计工具,用户可以任意添加二次滤波器,并且调整其中心频率、增益和Q值,并即时查看组合之后的最终频率响应。 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb44bad8.png) ``` # -*- coding: utf-8 -*- import math from enthought.traits.api import Float, HasTraits, List, Array, on_trait_change, Instance, Range, Button from enthought.traits.ui.api import View, TableEditor, Item, Group, HGroup, VGroup, HSplit, ScrubberEditor, EnumEditor from enthought.traits.ui.table_column import ObjectColumn from enthought.chaco.api import Plot, AbstractPlotData, ArrayPlotData, VPlotContainer from enthought.chaco.tools.api import PanTool, ZoomTool from enthought.enable.api import Component, ComponentEditor from enthought.pyface.api import FileDialog, OK import pickle import numpy as np SAMPLING_RATE = 44100.0 # 取样频率 WORN = 1000 # 频率响应曲线的点数 # 对数圆频率数组 W = np.logspace(np.log10(10/SAMPLING_RATE*np.pi), np.log10(np.pi), WORN) # 对数频率数组 FREQS = W / 2 / np.pi * SAMPLING_RATE # 候选频率 EQ_FREQS = [20.0,25.2,31.7,40.0,50.4,63.5,80.0,100.8, 127.0,160.0,201.6,254.0,320.0,403.2,508.0,640.0, 806.3,1015.9,1280.0,1612.7,2031.9,2560.0,3225.4, 4063.7, 5120.0, 6450.8, 8127.5, 10240.0,12901.6, 16255.0,20480.0,] def scrubber(inc): '''创建不同增量的ScrubberEditor''' return ScrubberEditor( hover_color = 0xFFFFFF, active_color = 0xA0CD9E, border_color = 0x808080, increment = inc ) def myfreqz(b, a, w): '''计算滤波器在w个点的频率响应''' zm1 = np.exp(-1j*w) h = np.polyval(b[::-1], zm1) / np.polyval(a[::-1], zm1) return h def design_equalizer(freq, Q, gain, Fs): '''设计二次均衡滤波器的系数''' A = 10**(gain/40.0) w0 = 2*math.pi*freq/Fs alpha = math.sin(w0) / 2 / Q b0 = 1 + alpha * A b1 = -2*math.cos(w0) b2 = 1 - alpha * A a0 = 1 + alpha / A a1 = -2*math.cos(w0) a2 = 1 - alpha / A return [b0/a0,b1/a0,b2/a0], [1.0, a1/a0, a2/a0] class Equalizer(HasTraits): freq = Range(10.0, SAMPLING_RATE/2, 1000) Q = Range(0.1, 10.0, 1.0) gain = Range(-24.0, 24.0, 0) a = List(Float, [1.0,0.0,0.0]) b = List(Float, [1.0,0.0,0.0]) h = Array(dtype=np.complex, transient = True) def __init__(self): super(Equalizer, self).__init__() self.design_parameter() @on_trait_change("freq,Q,gain") def design_parameter(self): '''设计系数并计算频率响应''' try: self.b, self.a = design_equalizer(self.freq, self.Q, self.gain, SAMPLING_RATE) except: self.b, self.a = [1.0,0.0,0.0], [1.0,0.0,0.0] self.h = myfreqz(self.b, self.a, W) def export_parameters(self, f): '''输出滤波器系数为C语言数组''' tmp = self.b[0], self.b[1], self.b[2], self.a[1], self.a[2], self.freq, self.Q, self.gain f.write("{%s,%s,%s,%s,%s}, // %s,%s,%s\n" % tmp) class Equalizers(HasTraits): eqs = List(Equalizer, [Equalizer()]) h = Array(dtype=np.complex, transient = True) # Equalizer列表eqs的编辑器定义 table_editor = TableEditor( columns = [ ObjectColumn(name="freq", width=0.4, style="readonly"), ObjectColumn(name="Q", width=0.3, style="readonly"), ObjectColumn(name="gain", width=0.3, style="readonly"), ], deletable = True, sortable = True, auto_size = False, show_toolbar = True, edit_on_first_click = False, orientation = 'vertical', edit_view = View( Group( Item("freq", editor=EnumEditor(values=EQ_FREQS)), Item("freq", editor=scrubber(1.0)), Item("Q", editor=scrubber(0.01)), Item("gain", editor=scrubber(0.1)), show_border=True, ), resizable = True ), row_factory = Equalizer ) view = View( Item("eqs", show_label=False, editor=table_editor), width = 0.25, height = 0.5, resizable = True ) @on_trait_change("eqs.h") def recalculate_h(self): '''计算多组均衡器级联时的频率响应''' try: tmp = np.array([eq.h for eq in self.eqs if eq.h != None and len(eq.h) == len(W)]) self.h = np.prod(tmp, axis=0) except: pass def export(self, path): '''将均衡器的系数输出为C语言文件''' f = file(path, "w") f.write("double EQ_PARS[][5] = {\n") f.write("//b0,b1,b2,a0,a1 // frequency, Q, gain\n") for eq in self.eqs: eq.export_parameters(f) f.write("};\n") f.close() class EqualizerDesigner(HasTraits): '''均衡器设计器的主界面''' equalizers = Instance(Equalizers) # 保存绘图数据的对象 plot_data = Instance(AbstractPlotData) # 绘制波形图的容器 container = Instance(Component) plot_gain = Instance(Component) plot_phase = Instance(Component) save_button = Button("Save") load_button = Button("Load") export_button = Button("Export") view = View( VGroup( HGroup( Item("load_button"), Item("save_button"), Item("export_button"), show_labels = False ), HSplit( VGroup( Item("equalizers", style="custom", show_label=False), show_border=True, ), Item("container", editor=ComponentEditor(size=(800, 300)), show_label=False), ) ), resizable = True, width = 800, height = 500, title = u"Equalizer Designer" ) def _create_plot(self, data, name, type="line"): p = Plot(self.plot_data) p.plot(data, name=name, title=name, type=type) p.tools.append(PanTool(p)) zoom = ZoomTool(component=p, tool_mode="box", always_on=False) p.overlays.append(zoom) p.title = name p.index_scale = "log" return p def __init__(self): super(EqualizerDesigner, self).__init__() self.plot_data = ArrayPlotData(f=FREQS, gain=[], phase=[]) self.plot_gain = self._create_plot(("f", "gain"), "Gain(dB)") self.plot_phase = self._create_plot(("f", "phase"), "Phase(degree)") self.container = VPlotContainer() self.container.add( self.plot_phase ) self.container.add( self.plot_gain ) self.plot_gain.padding_bottom = 20 self.plot_phase.padding_top = 20 def _equalizers_default(self): return Equalizers() @on_trait_change("equalizers.h") def redraw(self): gain = 20*np.log10(np.abs(self.equalizers.h)) phase = np.angle(self.equalizers.h, deg=1) self.plot_data.set_data("gain", gain) self.plot_data.set_data("phase", phase) def _save_button_fired(self): dialog = FileDialog(action="save as", wildcard='EQ files (*.eq)|*.eq') result = dialog.open() if result == OK: f = file(dialog.path, "wb") pickle.dump( self.equalizers , f) f.close() def _load_button_fired(self): dialog = FileDialog(action="open", wildcard='EQ files (*.eq)|*.eq') result = dialog.open() if result == OK: f = file(dialog.path, "rb") self.equalizers = pickle.load(f) f.close() def _export_button_fired(self): dialog = FileDialog(action="save as", wildcard='c files (*.c)|*.c') result = dialog.open() if result == OK: self.equalizers.export(dialog.path) win = EqualizerDesigner() win.configure_traits() ```
';

FFT卷积的速度比较

最后更新于:2022-04-01 11:16:08

# FFT卷积的速度比较 相关文档: [_频域信号处理_](frequency_process.html) 直接卷积的复杂度为O(N*N),FFT的复杂度为O(N*log(N)),此程序分别计算直接卷积和快速卷积的耗时曲线。请注意Y轴为每点的平均运算时间。 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb5a73d7.png) ``` # -*- coding: utf-8 -*- import numpy as np import timeit def fft_convolve(a,b): n = len(a)+len(b)-1 N = 2**(int(np.log2(n))+1) A = np.fft.fft(a, N) B = np.fft.fft(b, N) return np.fft.ifft(A*B)[:n] if __name__ == "__main__": from pylab import * n_list = [] t1_list = [] t2_list = [] for n in xrange(4, 14): N = 2**n count = 10000**2 / N**2 if count > 10000: count = 10000 setup = """ import numpy as np from __main__ import fft_convolve a = np.random.rand(%s) b = np.random.rand(%s) """ % (N, N) t1 = timeit.timeit("np.convolve(a,b)", setup, number=count) t2 = timeit.timeit("fft_convolve(a,b)", setup, number=count) t1_list.append(t1*1000/count/N) t2_list.append(t2*1000/count/N) n_list.append(N) figure(figsize=(8,4)) plot(n_list, t1_list, label=u"直接卷积") plot(n_list, t2_list, label=u"FFT卷积") legend() title(u"卷积的计算时间") ylabel(u"计算时间(ms/point)") xlabel(u"长度") xlim(min(n_list),max(n_list)) show() ```
';

频谱泄漏和hann窗

最后更新于:2022-04-01 11:16:06

# 频谱泄漏和hann窗 相关文档: [_频域信号处理_](frequency_process.html) 对于8kHz取样频率的200Hz 300Hz的叠加波形进行512点FFT计算其频谱,比较矩形窗和hann窗的频谱泄漏。 ``` # -*- coding: utf-8 -*- #用hann窗降低频谱泄漏 # import numpy as np import pylab as pl import scipy.signal as signal sampling_rate = 8000 fft_size = 512 t = np.arange(0, 1.0, 1.0/sampling_rate) x = np.sin(2*np.pi*200*t) + 2*np.sin(2*np.pi*300*t) xs = x[:fft_size] ys = xs * signal.hann(fft_size, sym=0) xf = np.fft.rfft(xs)/fft_size yf = np.fft.rfft(ys)/fft_size freqs = np.linspace(0, sampling_rate/2, fft_size/2+1) xfp = 20*np.log10(np.clip(np.abs(xf), 1e-20, 1e100)) yfp = 20*np.log10(np.clip(np.abs(yf), 1e-20, 1e100)) pl.figure(figsize=(8,4)) pl.title(u"200Hz和300Hz的波形和频谱") pl.plot(freqs, xfp, label=u"矩形窗") pl.plot(freqs, yfp, label=u"hann窗") pl.legend() pl.xlabel(u"频率(Hz)") a = pl.axes([.4, .2, .4, .4]) a.plot(freqs, xfp, label=u"矩形窗") a.plot(freqs, yfp, label=u"hann窗") a.set_xlim(100, 400) a.set_ylim(-40, 0) pl.show() ```
';

三维标量场观察器

最后更新于:2022-04-01 11:16:04

# 三维标量场观察器 相关文档: [_将Mayavi嵌入到界面中_](mlab_and_mayavi.html#sec-mayavi-embed) ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bad70c9b.png) ``` # -*- coding: utf-8 -*- import numpy as np from numpy import * from enthought.traits.api import * from enthought.traits.ui.api import * from enthought.tvtk.pyface.scene_editor import SceneEditor from enthought.mayavi.tools.mlab_scene_model import MlabSceneModel from enthought.mayavi.core.ui.mayavi_scene import MayaviScene class FieldViewer(HasTraits): """三维标量场观察器""" # 三个轴的取值范围 x0, x1 = Float(-5), Float(5) y0, y1 = Float(-5), Float(5) z0, z1 = Float(-5), Float(5) points = Int(50) # 分割点数 autocontour = Bool(True) # 是否自动计算等值面 v0, v1 = Float(0.0), Float(1.0) # 等值面的取值范围 contour = Range("v0", "v1", 0.5) # 等值面的值 function = Str("x*x*0.5 + y*y + z*z*2.0") # 标量场函数 plotbutton = Button(u"描画") scene = Instance(MlabSceneModel, ()) # mayavi场景 view = View( HSplit( VGroup( "x0","x1","y0","y1","z0","z1", Item('points', label=u"点数"), Item('autocontour', label=u"自动等值"), Item('plotbutton', show_label=False), ), VGroup( Item(name='scene', editor=SceneEditor(scene_class=MayaviScene), # 设置mayavi的编辑器 resizable=True, height=300, width=350 ), 'function', Item('contour', editor=RangeEditor(format="%1.2f", low_name="v0", high_name="v1") ), show_labels=False ) ), width = 500, resizable=True, title=u"三维标量场观察器" ) def _plotbutton_fired(self): self.plot() def _autocontour_changed(self): "自动计算等值平面的设置改变事件响应" if hasattr(self, "g"): self.g.contour.auto_contours = self.autocontour if not self.autocontour: self._contour_changed() def _contour_changed(self): "等值平面的值改变事件响应" if hasattr(self, "g"): if not self.g.contour.auto_contours: self.g.contour.contours = [self.contour] def plot(self): "绘制场景" # 产生三维网格 x, y, z = mgrid[ self.x0:self.x1:1j*self.points, self.y0:self.y1:1j*self.points, self.z0:self.z1:1j*self.points] scalars = eval(self.function) # 根据函数计算标量场的值 self.scene.mlab.clf() # 清空当前场景 # 绘制等值平面 g = self.scene.mlab.contour3d(x, y, z, scalars, contours=8, transparent=True) g.contour.auto_contours = self.autocontour self.scene.mlab.axes() # 添加坐标轴 # 添加一个X-Y的切面 s = self.scene.mlab.pipeline.scalar_cut_plane(g) cutpoint = (self.x0+self.x1)/2, (self.y0+self.y1)/2, (self.z0+self.z1)/2 s.implicit_plane.normal = (0,0,1) # x cut s.implicit_plane.origin = cutpoint self.g = g self.scalars = scalars # 计算标量场的值的范围 self.v0 = np.min(scalars) self.v1 = np.max(scalars) app = FieldViewer() app.configure_traits() ```
';

NLMS算法的模拟测试

最后更新于:2022-04-01 11:16:02

# NLMS算法的模拟测试 相关文档: [_自适应滤波器和NLMS模拟_](fast_nlms_in_python.html) 测试NLMS在系统辨识、信号预测和信号均衡方面的应用。 ``` # -*- coding: utf-8 -*- # filename: nlms_test.py import numpy as np import pylab as pl import nlms_numpy import scipy.signal # 随机产生FIR滤波器的系数,长度为length, 延时为delay, 指数衰减 def make_path(delay, length): path_length = length - delay h = np.zeros(length, np.float64) h[delay:] = np.random.standard_normal(path_length) * np.exp( np.linspace(0, -4, path_length) ) h /= np.sqrt(np.sum(h*h)) return h def plot_converge(y, u, label=""): size = len(u) avg_number = 200 e = np.power(y[:size] - u, 2) tmp = e[:int(size/avg_number)*avg_number] tmp.shape = -1, avg_number avg = np.average( tmp, axis=1 ) pl.plot(np.linspace(0, size, len(avg)), 10*np.log10(avg), linewidth=2.0, label=label) def diff_db(h0, h): return 10*np.log10(np.sum((h0-h)*(h0-h)) / np.sum(h0*h0)) # 用NLMS进行系统辨识的模拟, 未知系统的传递函数为h0, 使用的参照信号为x def sim_system_identify(nlms, x, h0, step_size, noise_scale): y = np.convolve(x, h0) d = y + np.random.standard_normal(len(y)) * noise_scale # 添加白色噪声的外部干扰 h = np.zeros(len(h0), np.float64) # 自适应滤波器的长度和未知系统长度相同,初始值为0 u = nlms( x, d, h, step_size ) return y, u, h def system_identify_test1(): h0 = make_path(32, 256) # 随机产生一个未知系统的传递函数 x = np.random.standard_normal(10000) # 参照信号为白噪声 y, u, h = sim_system_identify(nlms_numpy.nlms, x, h0, 0.5, 0.1) print diff_db(h0, h) pl.figure( figsize=(8, 6) ) pl.subplot(211) pl.subplots_adjust(hspace=0.4) pl.plot(h0, c="r") pl.plot(h, c="b") pl.title(u"未知系统和收敛后的滤波器的系数比较") pl.subplot(212) plot_converge(y, u) pl.title(u"自适应滤波器收敛特性") pl.xlabel("Iterations (samples)") pl.ylabel("Converge Level (dB)") pl.show() def system_identify_test2(): h0 = make_path(32, 256) # 随机产生一个未知系统的传递函数 x = np.random.standard_normal(20000) # 参照信号为白噪声 pl.figure(figsize=(8,4)) for step_size in np.arange(0.1, 1.0, 0.2): y, u, h = sim_system_identify(nlms_numpy.nlms, x, h0, step_size, 0.1) plot_converge(y, u, label=u"μ=%s" % step_size) pl.title(u"更新系数和收敛特性的关系") pl.xlabel("Iterations (samples)") pl.ylabel("Converge Level (dB)") pl.legend() pl.show() def system_identify_test3(): h0 = make_path(32, 256) # 随机产生一个未知系统的传递函数 x = np.random.standard_normal(20000) # 参照信号为白噪声 pl.figure(figsize=(8,4)) for noise_scale in [0.05, 0.1, 0.2, 0.4, 0.8]: y, u, h = sim_system_identify(nlms_numpy.nlms, x, h0, 0.5, noise_scale) plot_converge(y, u, label=u"noise=%s" % noise_scale) pl.title(u"外部干扰和收敛特性的关系") pl.xlabel("Iterations (samples)") pl.ylabel("Converge Level (dB)") pl.legend() pl.show() def sim_signal_equation(nlms, x, h0, D, step_size, noise_scale): d = x[:-D] x = x[D:] y = np.convolve(x, h0)[:len(x)] h = np.zeros(2*len(h0)+2*D, np.float64) y += np.random.standard_normal(len(y)) * noise_scale u = nlms(y, d, h, step_size) return h def signal_equation_test1(): h0 = make_path(5, 64) D = 128 length = 20000 data = np.random.standard_normal(length+D) h = sim_signal_equation(nlms_numpy.nlms, data, h0, D, 0.5, 0.1) pl.figure(figsize=(8,4)) pl.plot(h0, label=u"未知系统") pl.plot(h, label=u"自适应滤波器") pl.plot(np.convolve(h0, h), label=u"二者卷积") pl.title(u"信号均衡演示") pl.legend() w0, H0 = scipy.signal.freqz(h0, worN = 1000) w, H = scipy.signal.freqz(h, worN = 1000) pl.figure(figsize=(8,4)) pl.plot(w0, 20*np.log10(np.abs(H0)), w, 20*np.log10(np.abs(H))) pl.title(u"未知系统和自适应滤波器的振幅特性") pl.xlabel(u"圆频率") pl.ylabel(u"振幅(dB)") pl.show() signal_equation_test1() ```
';

CSV文件数据图形化工具

最后更新于:2022-04-01 11:15:59

# CSV文件数据图形化工具 相关文档: [_设计自己的Trait编辑器_](traitsui_manual_custom_editor.html) 采用 [_在traitsUI中使用的matplotlib控件_](example_mpl_figure_editor.html) 制作的CSV文件数据绘图工具。 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1baed4140.png) ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1baeebc0d.png) ``` # -*- coding: utf-8 -*- from matplotlib.figure import Figure from mpl_figure_editor import MPLFigureEditor from enthought.traits.ui.api import * from enthought.traits.api import * import csv class DataSource(HasTraits): """ 数据源,data是一个字典,将字符串映射到列表 names是data中的所有字符串的列表 """ data = DictStrAny names = List(Str) def load_csv(self, filename): """ 从CSV文件读入数据,更新data和names属性 """ f = file(filename) reader = csv.DictReader(f) self.names = reader.fieldnames for field in reader.fieldnames: self.data[field] = [] for line in reader: for k, v in line.iteritems(): self.data[k].append(float(v)) f.close() class Graph(HasTraits): """ 绘图组件,包括左边的数据选择控件和右边的绘图控件 """ name = Str # 绘图名,显示在标签页标题和绘图标题中 data_source = Instance(DataSource) # 保存数据的数据源 figure = Instance(Figure) # 控制绘图控件的Figure对象 selected_xaxis = Str # X轴所用的数据名 selected_items = List # Y轴所用的数据列表 clear_button = Button(u"清除") # 快速清除Y轴的所有选择的数据 view = View( HSplit( # HSplit分为左右两个区域,中间有可调节宽度比例的调节手柄 # 左边为一个组 VGroup( Item("name"), # 绘图名编辑框 Item("clear_button"), # 清除按钮 Heading(u"X轴数据"), # 静态文本 # X轴选择器,用EnumEditor编辑器,即ComboBox控件,控件中的候选数据从 # data_source的names属性得到 Item("selected_xaxis", editor= EnumEditor(name="object.data_source.names", format_str=u"%s")), Heading(u"Y轴数据"), # 静态文本 # Y轴选择器,由于Y轴可以多选,因此用CheckBox列表编辑,按两列显示 Item("selected_items", style="custom", editor=CheckListEditor(name="object.data_source.names", cols=2, format_str=u"%s")), show_border = True, # 显示组的边框 scrollable = True, # 组中的控件过多时,采用滚动条 show_labels = False # 组中的所有控件都不显示标签 ), # 右边绘图控件 Item("figure", editor=MPLFigureEditor(), show_label=False, width=600) ) ) def _name_changed(self): """ 当绘图名发生变化时,更新绘图的标题 """ axe = self.figure.axes[0] axe.set_title(self.name) self.figure.canvas.draw() def _clear_button_fired(self): """ 清除按钮的事件处理 """ self.selected_items = [] self.update() def _figure_default(self): """ figure属性的缺省值,直接创建一个Figure对象 """ figure = Figure() figure.add_axes([0.05, 0.1, 0.9, 0.85]) #添加绘图区域,四周留有边距 return figure def _selected_items_changed(self): """ Y轴数据选择更新 """ self.update() def _selected_xaxis_changed(self): """ X轴数据选择更新 """ self.update() def update(self): """ 重新绘制所有的曲线 """ axe = self.figure.axes[0] axe.clear() try: xdata = self.data_source.data[self.selected_xaxis] except: return for field in self.selected_items: axe.plot(xdata, self.data_source.data[field], label=field) axe.set_xlabel(self.selected_xaxis) axe.set_title(self.name) axe.legend() self.figure.canvas.draw() class CSVGrapher(HasTraits): """ 主界面包括绘图列表,数据源,文件选择器和添加绘图按钮 """ graph_list = List(Instance(Graph)) # 绘图列表 data_source = Instance(DataSource) # 数据源 csv_file_name = File(filter=[u"*.csv"]) # 文件选择 add_graph_button = Button(u"添加绘图") # 添加绘图按钮 view = View( # 整个窗口分为上下两个部分 VGroup( # 上部分横向放置控件,因此用HGroup HGroup( # 文件选择控件 Item("csv_file_name", label=u"选择CSV文件", width=400), # 添加绘图按钮 Item("add_graph_button", show_label=False) ), # 下部分是绘图列表,采用ListEditor编辑器显示 Item("graph_list", style="custom", show_label=False, editor=ListEditor( use_notebook=True, # 是用多标签页格式显示 deletable=True, # 可以删除标签页 dock_style="tab", # 标签dock样式 page_name=".name") # 标题页的文本使用Graph对象的name属性 ) ), resizable = True, height = 0.8, width = 0.8, title = u"CSV数据绘图器" ) def _csv_file_name_changed(self): """ 打开新文件时的处理,根据文件创建一个DataSource """ self.data_source = DataSource() self.data_source.load_csv(self.csv_file_name) del self.graph_list[:] def _add_graph_button_changed(self): """ 添加绘图按钮的事件处理 """ if self.data_source != None: self.graph_list.append( Graph(data_source = self.data_source) ) if __name__ == "__main__": csv_grapher = CSVGrapher() csv_grapher.configure_traits() ```
';

在traitsUI中使用的matplotlib控件

最后更新于:2022-04-01 11:15:57

# 在traitsUI中使用的matplotlib控件 相关文档: [_设计自己的Trait编辑器_](traitsui_manual_custom_editor.html) 在traitsUI所产生的界面中嵌入matplotlib的控件。 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1baea865a.png) ``` # -*- coding: utf-8 -*- # file name: mpl_figure_editor.py import wx import matplotlib # matplotlib采用WXAgg为后台,这样才能将绘图控件嵌入以wx为后台界面库的traitsUI窗口中 matplotlib.use("WXAgg") from matplotlib.backends.backend_wxagg import FigureCanvasWxAgg as FigureCanvas from matplotlib.backends.backend_wx import NavigationToolbar2Wx from enthought.traits.ui.wx.editor import Editor from enthought.traits.ui.basic_editor_factory import BasicEditorFactory class _MPLFigureEditor(Editor): """ 相当于wx后台界面库中的编辑器,它负责创建真正的控件 """ scrollable = True def init(self, parent): self.control = self._create_canvas(parent) self.set_tooltip() print dir(self.item) def update_editor(self): pass def _create_canvas(self, parent): """ 创建一个Panel, 布局采用垂直排列的BoxSizer, panel中中添加 FigureCanvas, NavigationToolbar2Wx, StaticText三个控件 FigureCanvas的鼠标移动事件调用mousemoved函数,在StaticText 显示鼠标所在的数据坐标 """ panel = wx.Panel(parent, -1, style=wx.CLIP_CHILDREN) def mousemoved(event): panel.info.SetLabel("%s, %s" % (event.xdata, event.ydata)) panel.mousemoved = mousemoved sizer = wx.BoxSizer(wx.VERTICAL) panel.SetSizer(sizer) mpl_control = FigureCanvas(panel, -1, self.value) mpl_control.mpl_connect("motion_notify_event", mousemoved) toolbar = NavigationToolbar2Wx(mpl_control) sizer.Add(mpl_control, 1, wx.LEFT | wx.TOP | wx.GROW) sizer.Add(toolbar, 0, wx.EXPAND|wx.RIGHT) panel.info = wx.StaticText(parent, -1) sizer.Add(panel.info) self.value.canvas.SetMinSize((10,10)) return panel class MPLFigureEditor(BasicEditorFactory): """ 相当于traits.ui中的EditorFactory,它返回真正创建控件的类 """ klass = _MPLFigureEditor if __name__ == "__main__": from matplotlib.figure import Figure from enthought.traits.api import HasTraits, Instance from enthought.traits.ui.api import View, Item from numpy import sin, cos, linspace, pi class Test(HasTraits): figure = Instance(Figure, ()) view = View( Item("figure", editor=MPLFigureEditor(), show_label=False), width = 400, height = 300, resizable = True) def __init__(self): super(Test, self).__init__() axes = self.figure.add_subplot(111) t = linspace(0, 2*pi, 200) axes.plot(sin(t)) Test().configure_traits() ```
';

三角波的FFT演示

最后更新于:2022-04-01 11:15:55

# 三角波的FFT演示 相关文档: [_FFT演示程序_](fft_study.html) 本程序演示各种三角波形的FFT频谱,用户可以方便地修改三角波的各个参数,并立即看到其FFT频谱的变化。 ![](_images/fft_study_04.swf) ``` # -*- coding: utf-8 -*- from enthought.traits.api import \ Str, Float, HasTraits, Property, cached_property, Range, Instance, on_trait_change, Enum from enthought.chaco.api import Plot, AbstractPlotData, ArrayPlotData, VPlotContainer from enthought.traits.ui.api import \ Item, View, VGroup, HSplit, ScrubberEditor, VSplit from enthought.enable.api import Component, ComponentEditor from enthought.chaco.tools.api import PanTool, ZoomTool import numpy as np # 鼠标拖动修改值的控件的样式 scrubber = ScrubberEditor( hover_color = 0xFFFFFF, active_color = 0xA0CD9E, border_color = 0x808080 ) # 取FFT计算的结果freqs中的前n项进行合成,返回合成结果,计算loops个周期的波形 def fft_combine(freqs, n, loops=1): length = len(freqs) * loops data = np.zeros(length) index = loops * np.arange(0, length, 1.0) / length * (2 * np.pi) for k, p in enumerate(freqs[:n]): if k != 0: p *= 2 # 除去直流成分之外,其余的系数都*2 data += np.real(p) * np.cos(k*index) # 余弦成分的系数为实数部 data -= np.imag(p) * np.sin(k*index) # 正弦成分的系数为负的虚数部 return index, data class TriangleWave(HasTraits): # 指定三角波的最窄和最宽范围,由于Range似乎不能将常数和traits名混用 # 所以定义这两个不变的trait属性 low = Float(0.02) hi = Float(1.0) # 三角波形的宽度 wave_width = Range("low", "hi", 0.5) # 三角波的顶点C的x轴坐标 length_c = Range("low", "wave_width", 0.5) # 三角波的定点的y轴坐标 height_c = Float(1.0) # FFT计算所使用的取样点数,这里用一个Enum类型的属性以供用户从列表中选择 fftsize = Enum( [(2**x) for x in range(6, 12)]) # FFT频谱图的x轴上限值 fft_graph_up_limit = Range(0, 400, 20) # 用于显示FFT的结果 peak_list = Str # 采用多少个频率合成三角波 N = Range(1, 40, 4) # 保存绘图数据的对象 plot_data = Instance(AbstractPlotData) # 绘制波形图的容器 plot_wave = Instance(Component) # 绘制FFT频谱图的容器 plot_fft = Instance(Component) # 包括两个绘图的容器 container = Instance(Component) # 设置用户界面的视图, 注意一定要指定窗口的大小,这样绘图容器才能正常初始化 view = View( HSplit( VSplit( VGroup( Item("wave_width", editor = scrubber, label=u"波形宽度"), Item("length_c", editor = scrubber, label=u"最高点x坐标"), Item("height_c", editor = scrubber, label=u"最高点y坐标"), Item("fft_graph_up_limit", editor = scrubber, label=u"频谱图范围"), Item("fftsize", label=u"FFT点数"), Item("N", label=u"合成波频率数") ), Item("peak_list", style="custom", show_label=False, width=100, height=250) ), VGroup( Item("container", editor=ComponentEditor(size=(600,300)), show_label = False), orientation = "vertical" ) ), resizable = True, width = 800, height = 600, title = u"三角波FFT演示" ) # 创建绘图的辅助函数,创建波形图和频谱图有很多类似的地方,因此单独用一个函数以 # 减少重复代码 def _create_plot(self, data, name, type="line"): p = Plot(self.plot_data) p.plot(data, name=name, title=name, type=type) p.tools.append(PanTool(p)) zoom = ZoomTool(component=p, tool_mode="box", always_on=False) p.overlays.append(zoom) p.title = name return p def __init__(self): # 首先需要调用父类的初始化函数 super(TriangleWave, self).__init__() # 创建绘图数据集,暂时没有数据因此都赋值为空,只是创建几个名字,以供Plot引用 self.plot_data = ArrayPlotData(x=[], y=[], f=[], p=[], x2=[], y2=[]) # 创建一个垂直排列的绘图容器,它将频谱图和波形图上下排列 self.container = VPlotContainer() # 创建波形图,波形图绘制两条曲线: 原始波形(x,y)和合成波形(x2,y2) self.plot_wave = self._create_plot(("x","y"), "Triangle Wave") self.plot_wave.plot(("x2","y2"), color="red") # 创建频谱图,使用数据集中的f和p self.plot_fft = self._create_plot(("f","p"), "FFT", type="scatter") # 将两个绘图容器添加到垂直容器中 self.container.add( self.plot_wave ) self.container.add( self.plot_fft ) # 设置 self.plot_wave.x_axis.title = "Samples" self.plot_fft.x_axis.title = "Frequency pins" self.plot_fft.y_axis.title = "(dB)" # 改变fftsize为1024,因为Enum的默认缺省值为枚举列表中的第一个值 self.fftsize = 1024 # FFT频谱图的x轴上限值的改变事件处理函数,将最新的值赋值给频谱图的响应属性 def _fft_graph_up_limit_changed(self): self.plot_fft.x_axis.mapper.range.high = self.fft_graph_up_limit def _N_changed(self): self.plot_sin_combine() # 多个trait属性的改变事件处理函数相同时,可以用@on_trait_change指定 @on_trait_change("wave_width, length_c, height_c, fftsize") def update_plot(self): # 计算三角波 global y_data x_data = np.arange(0, 1.0, 1.0/self.fftsize) func = self.triangle_func() # 将func函数的返回值强制转换成float64 y_data = np.cast["float64"](func(x_data)) # 计算频谱 fft_parameters = np.fft.fft(y_data) / len(y_data) # 计算各个频率的振幅 fft_data = np.clip(20*np.log10(np.abs(fft_parameters))[:self.fftsize/2+1], -120, 120) # 将计算的结果写进数据集 self.plot_data.set_data("x", np.arange(0, self.fftsize)) # x坐标为取样点 self.plot_data.set_data("y", y_data) self.plot_data.set_data("f", np.arange(0, len(fft_data))) # x坐标为频率编号 self.plot_data.set_data("p", fft_data) # 合成波的x坐标为取样点,显示2个周期 self.plot_data.set_data("x2", np.arange(0, 2*self.fftsize)) # 更新频谱图x轴上限 self._fft_graph_up_limit_changed() # 将振幅大于-80dB的频率输出 peak_index = (fft_data > -80) peak_value = fft_data[peak_index][:20] result = [] for f, v in zip(np.flatnonzero(peak_index), peak_value): result.append("%s : %s" %(f, v) ) self.peak_list = "\n".join(result) # 保存现在的fft计算结果,并计算正弦合成波 self.fft_parameters = fft_parameters self.plot_sin_combine() # 计算正弦合成波,计算2个周期 def plot_sin_combine(self): index, data = fft_combine(self.fft_parameters, self.N, 2) self.plot_data.set_data("y2", data) # 返回一个ufunc计算指定参数的三角波 def triangle_func(self): c = self.wave_width c0 = self.length_c hc = self.height_c def trifunc(x): x = x - int(x) # 三角波的周期为1,因此只取x坐标的小数部分进行计算 if x >= c: r = 0.0 elif x < c0: r = x / c0 * hc else: r = (c-x) / (c-c0) * hc return r # 用trifunc函数创建一个ufunc函数,可以直接对数组进行计算, 不过通过此函数 # 计算得到的是一个Object数组,需要进行类型转换 return np.frompyfunc(trifunc, 1, 1) if __name__ == "__main__": triangle = TriangleWave() triangle.configure_traits() ```
';

源程序集

最后更新于:2022-04-01 11:15:53

# 源程序集 * [三角波的FFT演示](example_fft_triangle_GUI.html) * [在traitsUI中使用的matplotlib控件](example_mpl_figure_editor.html) * [CSV文件数据图形化工具](example_traitsUI_csv_viewer.html) * [NLMS算法的模拟测试](example_nlms_test.html) * [三维标量场观察器](example_mayavi_embed_fieldviewer.html) * [频谱泄漏和hann窗](example_spectrum_example_hann.html) * [FFT卷积的速度比较](example_spectrum_fft_convolve_timeit.html) * [二次均衡器设计](example_equalizer_designer.html) * [单摆摆动周期的计算](example_simple_pendulum_period.html) * [双摆系统的动画模拟](example_double_pendulum.html) * [绘制Mandelbrot集合](example_mandelbrot.html) * [迭代函数系统的分形](example_ifs.html) * [绘制L-System的分形图](example_lsystem.html)
';

最近更新

最后更新于:2022-04-01 11:15:50

# 最近更新 * 2010/01/15: [_将Mayavi嵌入到界面中_](mlab_and_mayavi.html#sec-mayavi-embed) * 2010/01/14: [_模拟IIR滤波器的频带转换_](filters.html#sec-iirbandtrans) * 2010/01/12: 修改Sphinx模板,添加支持中文搜索的插件,中文分词库采用 > smallseg: [http://code.google.com/p/smallseg](http://code.google.com/p/smallseg) * 2010/01/07: [_巴特沃斯低通滤波器_](filters.html#sec-filter-butter) ; [_双线性变换_](filters.html#sec-filter-bilinear) * 2010/01/05: [_用Sympy计算球体体积_](sympy_intro.html#sec-sympy-sphere) ; [_NumPy-快速处理数据_](numpy_intro.html) 添加少许新内容;修改章节名 * 2010/01/04: [_L-System分形_](fractal_chaos.html#sec-lsystem) * 2010/01/03: [_迭代函数系统设计器_](fractal_chaos.html#sec-ifs-designer) * 2010/01/02: [_迭代函数系统(IFS)_](fractal_chaos.html#sec-ifs) * 2009/12/30 : [_Matplotlib的Axis对象_](matplotlib_intro.html#sec-matplotlib-axis) * 2009/12/29 : [_绘制Mandelbrot集合_](fractal_chaos.html#sec-mandelbrot)
';

关于本书的编写

最后更新于:2022-04-01 11:15:48

# 关于本书的编写 为了编写此书,我评价了许多写书的软件,最终决定使用Sphinx和reStructuredText作为写书的工具。随着章节的逐渐增加,我越来越觉得当初的选择是正确的。 ## 本书的编写工具 本书采用[reStructuredText](http://docutils.sourceforge.net/rst.html)(rst) 格式的文本编写,然后用[Sphinx](http://sphinx.pocoo.org)将reStructuredText文件自动转换为html格式的文件。采用[Leo](http://webpages.charter.net/edreamleo/front.html)管理和组织所有的文档。用[proTeXt](http://www.tug.org/protext)将latex格式的数学公式转换为PNG图片。 * **reStructuredText** : 一种结构化文本格式,它提供了对写书所需的各种元素的支持。例如章节、链接、格式、图片以及语法高亮等等。 * **Sphinx** : 将一系列reStructuredText文本转换成各种不同的输出格式,并自动制作交叉引用(cross-references)、索引等。也就是说,如果某目录中有一系列的reStructuredText格式的文档, Sphinx可以制作一份组织得非常完美的HTML文件。 * **Leo** : 以树状结构管理文本、代码的编辑器,可以用它来进行数据组织和项目管理。我使用它管理构成本书的所有rst文档、python程序以及图片和笔记。下面是使用Leo编写本书时的一个例子: > ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bbc096c6.png) > > 编写本书所使用的Leo编辑器的界面 * **PicPick**, **Greenshot** : 界面截图工具。 ## 问题与解决方案 在使用上述工具编写本书时,为了达到完美的效果,我对这些工具做了一些配置和修改的工作。 ### 代码中的注释 Sphinx使用Pygments进行代码高亮的处理,在Pygments的缺省样式中,代码注释部分是采用斜体字表示的,斜体的汉字看起来十分别扭,因此需要将缺省样式的斜体改为正体。在conf.py文件中有如下配置: ``` # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' ``` 它指定pygments使用sphinx样式对代码进行高亮处理,我没有弄明白如何添加自己定义的样式,因此直接手工修改定义此样式的文件: ``` %Python安装目录%\Lib\site-packages\sphinx\highlighting.py ``` 将其中的Comment的样式改为noitalic: ``` ... styles.update({ Generic.Output: '#333', Comment: 'noitalic #408090', Number: '#208050', }) ... ``` ### 修改Sphinx的主题 为了给文档添加评论功能,必须添加一部分javascript代码,因此需要修改Shpinx的主题。 * 首先编辑conf.py文件中如下的两个配置: ``` # The theme to use for HTML and HTML Help pages. Major themes that come with # Sphinx are currently 'default' and 'sphinxdoc'. html_theme = 'pydoc' # Add any paths that contain custom themes here, relative to this directory. html_theme_path = ["./theme"] ``` * 然后在conf.py文件所在的目录下创建一个子目录theme,将sphinx安装目录下的themes\sphinxdoc文件夹复制到theme文件夹下,并重命名为pydoc,目录结构如下图所示: > ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bbc29188.png) > > theme文件夹的结构 * 编辑layout.html文件。此文件是一个模板,Sphinx最终使用此模板生成每个rst文件所对应的html文件。因此我在其中添加了对我自己的css和js文件的引用: ``` <link type="text/css" href="_static/jquery-ui-1.7.2.custom.css" rel="stylesheet" /> <link type="text/css" href="_static/comments.css" rel="stylesheet" /> <script type="text/javascript" src="_static/jquery-ui-1.7.2.custom.min.js"></script> <script type="text/javascript" src="_static/pydoc.js"></script> ``` * 在theme\pydoc\static目录下添加相应的css和js文件。为了固定html页面左侧的目录栏,可以配置theme\pydoc\theme.conf中的stickysidebar=True,不过好像IE7.0下无法正常显示,因此在css文件中添加如下代码,除了IE6.0以外其它的浏览器(Firefox,IE7, Chrome)都能够正常固定目录栏: > ``` > div.sphinxsidebar{ > position : fixed; > left : 0px; > top : 30px; > margin-left : 0px !important; > } > > ``` ### 关闭引号自动转换 在输出html的时候,如果使用Sphinx缺省的配置,会对引号进行自动转换:将标准的单引号和双引号转换为unicode中的全角引号。为了关闭此项功能,需要编辑 conf.py,进行如下设置: ``` html_use_smartypants = False ``` ### 用latex编写数学公式 Sphinx支持将latex编写的数学公式转换为png图片。为了在windows下使用latex,我下载了[proTeXt](http://www.tug.org/protext),这个tex软件包的大小有700M左右,安装之后占用1.3G。为了告诉Sphinx工具latex的安装位置,如下修改make.bat文件: ``` %SPHINXBUILD% -D pngmath_latex="..\latex.exe" -b html %ALLSPHINXOPTS% build/html ``` 然后就可以如下使用latex: ``` X_k = \sum_{n=0}^{N-1} x_n e^{-{i 2\pi k \frac{n}{N}}} \qquad k = 0,\dots,N-1. ``` 得到的输出图片如下: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bbc3947e.png) ### Leo的配置 Leo的缺省配置用起来很不习惯:它的树状目录在上方,而且字体很小。下面是对Leo的一些修改和配置: * Leo现在可以使用tk和qt两个库。使用tk库的界面用起来不习惯,因此通过在启动Leo时添加参数强制使用qt库的界面:launchLeo.py --gui=qt 。 * 我个人很喜欢微软雅黑的汉字字体,但是由于雅黑字体的英文不是等宽的,因此用它来编辑代码的话就很不爽了。于是上网找到了一个雅黑和Consolas的复合字体: > YaHei Mono字体下载地址: [http://hyry.dip.jp/files/yahei_mono.7z](http://hyry.dip.jp/files/yahei_mono.7z) * 复制一份leo\config\leoSettings.leo到同一目录,改名为myLeoSettings.leo。用Leo编辑此文件,在目录树中找到节点:qtGui plugin--&gt;@data qt-gui-plugin-style-sheet,修改此样式表中的字体的定义,使用新安装的Yahei Mono字体。 > ``` > QTextEdit#richTextEdit { > ... > font-family: Yahei Mono; > font-size: 17px; > ... > } > > ``` * 修改@settings--&gt;Window--&gt;@string initial_split_orientation节点和@settings--&gt;Window--&gt;Options for new windows--&gt;@strings[vertical,horizontal] initial_splitter_orientation节点的值为horizontal。这样目录树和编辑框就是左右分栏的了。 * 在Leo中用@auto-rst输出rst文件时,会自动的将目录树中的节点名转换为rst文件中的标题。在rst中标题都是由下划线标出的。下划线的长度要求和文本的长度一致。由于Leo采用unicode表示文本,因此汉字的长度为1,但是rst编译器似乎要求汉字的长度为2,因此对于 **Leo的配置** 这样的标题,rst要求用9个下划线符号标识,而Leo只用6个,造成在编译时出现许多警告信息,为了解决这个问题,编辑leo\core\leoRst.py文件中的underline函数如下,并且将其后的所有len(s)都改为len(ss): ``` def underline (self,s,p): ... try: ss = s.encode("gbk") except: try: ss = s.encode("shiftjis") except: ss = s trace = False and not g.unitTesting ... ``` ### 让Matplotlib显示中文 将中文字体文件复制到: ``` %PythonPath%\Lib\site-packages\matplotlib\mpl-data\fonts\ttf\ ``` 下,这里以上一节介绍的Yahei Mono字体为例。 找到Matplotlib的配置文件matplotlibrc,全局配置文件的路径: ``` %PythonPath%\Lib\site-packages\matplotlib\mpl-data\matplotlibrc ``` 用户配置文件路径: ``` c:\Documents and Settings\%UserName%\.matplotlib\matplotlibrc ``` 用文本编辑器打开此文件,进行如下编辑: * 找到设置font.family的行,改为font.family : monospace,注意去掉前面的#号。 * 在下面添加一行:font.monospace : Yahei Mono。 在matplotlib中使用中文字符串时记住要用unicode格式,例如:u"测试中文显示"。 ### 用Matplotlib生成图片 matplotlib提供了一个Sphinx的扩展插件,可以使用..plot命令自动生成图片。可是这个插件生成的图片的路径和本书所采用的路径不符合,无法在HTML文件中显示最终生成的图。因此我直接复制下面两个文件: ``` c:\Python26\Lib\site-packages\matplotlib\sphinxext\plot_directive.py c:\Python26\Lib\site-packages\matplotlib\sphinxext\only_directives.py ``` 到sourceexts下,命名为plot_directive.py。然后编辑conf.py,修改下面两行: ``` sys.path.append(os.path.abspath('exts')) extensions = ['sphinx.ext.autodoc', 'sphinx.ext.doctest', 'sphinx.ext.pngmath', 'plot_directive'] ``` 这样就可以载入extsplot_directive.py扩展插件了。然后编辑plot_directive.py文件,使得它的输出符合本书的路径,并且除去大图和PDF输出。 在rst文件中使用: ``` import matplotlib.pyplot as plt import numpy as np x = np.random.randn(1000) plt.hist( x, 20) plt.grid() plt.title(r'Normal: $\mu=%.2f, \sigma=%.2f$'%(x.mean(), x.std())) plt.show() ``` ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bbc49a9a.png) ### 用Graphviz绘图 Sphinx可以调用Graphviz绘制流程图,首先下载Graphviz的Windows安装包进行安装,假设安装目录为c:\graphviz。 Graphviz的下载地址: [http://www.graphviz.org](http://www.graphviz.org) 编辑conf.py配置文件,在 extensions 列表定义的最后添加一项:'sphinx.ext.graphviz'。 如下编辑make.bat文件,配置dot.exe的执行路径: ``` .. graphviz:: digraph GraphvizDemo{ node [fontname="Yahei Mono" shape="rect"]; edge [fontname="Yahei Mono" fontsize=10]; node1[label="开始"]; node2[label="正常"]; node1->node2[label="测试"]; } ``` 输出图为: ![digraph GraphvizDemo{ node [fontname="Yahei Mono" shape="rect"]; edge [fontname="Yahei Mono" fontsize=10]; node1[label="开始"]; node2[label="正常"]; node1-&gt;node2[label="测试"]; }](img/graphviz-691597b9de6125817b93aaad942bf30f1e3d5346.png) ### 制作CHM文档 Sphinx支持输出为CHM文档格式,只需要运行make htmlhelp即可。但是此命令输出的目录文件(扩展名为.hhc),却不支持中文。为了解决这个问题,我进行了如下修改: * sphinx的安装目录下找到buildershtmlhelp.py,将其复制一份,改名为htmlhelpcn.py。输出CHM文档的程序都在这里面。 * 修改builders\_\_init\_\_.py文件,在其最后的BUILTIN_BUILDERS字典定义中添加一行: > ``` > 'htmlhelpcn': ('htmlhelpcn', 'HTMLHelpBuilder') > > ``` * 修改make.bat文件,在其中添加: > ``` > if "%1" == "htmlhelpcn" ( > %SPHINXBUILD% -b htmlhelpcn %ALLSPHINXOPTS% build/htmlhelpcn > echo. > echo.Build finished; now you can run HTML Help Workshop with the ^ > .hhp project file in build/htmlhelpcn. > goto end > ) > > ``` * 编辑htmlhelpcn.py文件,找到project_template字符串的定义,修改其中的Language定义为Language=0x804。 * 反复运行make.bat htmlhelpcn命令,根据输出的错误提示修改htmlhelpcn.py,将其中几处编码错误的地方都添加.encode("gb2312")。其中有一处: > ``` > f.write(item.encode('ascii', 'xmlcharrefreplace')) > > # 改为--&gt; > > f.write(item.encode('gb2312')) > > ``` * 如果在rst文档中给图片添加了中文说明的话,有可能输出的CHM文件中看不到图片。 * make.bat htmlhelpcn正常运行之后,运行下面的命令输出制作CHM文件: > ``` > "C:\Program Files\HTML Help Workshop\hhc.exe" htmlhelpcn\scipydoc.hhp > > ``` ### CHM中嵌入Flash动画 用如下的reStructuredText的 raw 指令可以在html中嵌入Flash动画: ``` <OBJECT CLASSID="clsid:D27CDB6E-AE6D-11cf-96B8-444553540000" WIDTH="589" HEIGHT="447" CODEBASE="http://active.macromedia.com/flash5/cabs/swflash.cab#version=7,0,0,0"> <PARAM NAME="movie" VALUE="img/fft_study_04.swf"> <PARAM NAME="play" VALUE="true"> <PARAM NAME="loop" VALUE="false"> <PARAM NAME="wmode" VALUE="transparent"> <PARAM NAME="quality" VALUE="high"> <EMBED SRC="img/fft_study_04.swf" width="589" HEIGHT="447" quality="high" loop="false" wmode="transparent" TYPE="application/x-shockwave-flash" PLUGINSPAGE= "http://www.macromedia.com/shockwave/download/index.cgi?P1_Prod_Version=ShockwaveFlash"> </EMBED> </OBJECT> ``` 由于Html Help Workshop不会将swf文件打包进CHM,因此CHM中看不到flash动画,只需要在嵌入flash动画的html之后添加一条: ``` <img src="img/fft_study_04.swf" style="visibility:hidden"/> ``` 这样Html Help Workshop就会把fft_study_04.swf文件添加进去,由于使用隐藏的CSS,页面中也不会把它当作图片显示出来。 ### 制作PDF文档 调用make latex命令可以输出为latex格式的文件,然后调用 xelatex scipydoc.tex 即可将其转换为PDF文件,xelatex是proTeXt自带的命令。制作PDF文档时同样有中文无法显示的问题,按照以下步骤解决: * 编辑文档的配置文件conf.py,在最后的 Options for LaTeX output 定义处,添加如下代码,这段文字将添加到最终输出的tex文件中,这里的Yahei Mono可以修改为你想要的中文字体名: ``` latex_preamble = r""" \usepackage{float} \textwidth 6.5in \oddsidemargin -0.2in \evensidemargin -0.2in \usepackage{ccaption} \usepackage{fontspec,xunicode,xltxtra} \setsansfont{Microsoft YaHei} \setromanfont{Microsoft YaHei} \setmainfont{Microsoft YaHei} \setmonofont{Yahei Mono} \XeTeXlinebreaklocale "zh" \XeTeXlinebreakskip = 0pt plus 1pt \renewcommand{\baselinestretch}{1.3} \setcounter{tocdepth}{3} \captiontitlefont{\small\sffamily} \captiondelim{ - } \renewcommand\today{\number\year年\number\month月\number\day日} \makeatletter \renewcommand*\l@subsection{\@dottedtocline{2}{2.0em}{4.0em}} \renewcommand*\l@subsubsection{\@dottedtocline{3}{3em}{5em}} \makeatother \titleformat{\chapter}[display] {\bfseries\Huge} {\filleft \Huge 第 \hspace{2 mm} \thechapter \hspace{4 mm} 章} {4ex} {\titlerule \vspace{2ex}% \filright} [\vspace{2ex}% \titlerule] %\definecolor{VerbatimBorderColor}{rgb}{0.2,0.2,0.2} \definecolor{VerbatimColor}{rgb}{0.95,0.95,0.95} """.decode("utf-8") ``` 通过renewcommand命令将输出的PDF文档中的一部分英文修改为中文。 不知何故,在latex_preamble中添加修改插图标题前缀的命令没有作用,因此通过下面的命令在正文中添加转换前缀的renewcommand: ``` .. raw:: latex \renewcommand\partname{部分} \renewcommand{\chaptermark}[1]{\markboth{第 \thechapter\ 章 \hspace{4mm} #1}{}} \fancyhead[LE,RO]{用Python做科学计算} \renewcommand{\figurename}{\textsc{图}} ``` * 调整conf.py中的其它选项: > ``` > latex_paper_size = 'a4' > latex_font_size = '11pt' > latex_use_modindex = False > > ``` * 运行下面的命令输出PDF文档,使用nonstopmode,即使出现错误也不暂停运行。 > ``` > xelatex -interaction=nonstopmode scipydoc.tex > > ``` 还有一些latex配置没有找到如何使用reStructuredText进行设置,因此写了一个Python的小程序读取输出的tex文件,替换其中的一些latex命令: * 将begin{figure}[htbp]改为begin{figure}[H},这样能保证图和文字保持tex中的前后关系,而不会对图进行自动排版 * 在\tableofcontents之前添加\renewcommand\contentsname{目 录},将目录标题的英文改为中文,此段配置在latex_preamble中定义无效 ### 添加PDF封面 使用作图软件设计封面图片之后,使用图片转PDF工具将其转换为一个只有一页的PDF文档cover.pdf: 图片转PDF工具下载地址: [http://www.softinterface.com](http://www.softinterface.com) 然后使用PDF合并工具将cover.pdf和正文的PDF文件进行合并。我在网络上找了很久,终于找到了下面这个能够维持内部链接和书签的免费的合并工具: PDF工具PDFsam下载地址: [http://www.pdfsam.org](http://www.pdfsam.org) PDFsam提供了界面和命令行方式,界面方式很容易使用,但是为了一个批处理产生最终PDF文档我需要使用命令行方式,下面是使用命令行进行PDF文档合并的批处理程序: ``` set MERGE=java -jar "c:\Program Files\pdfsam\lib\pdfsam-console-2.2.0e.jar" %MERGE% -f cover.pdf -f scipydoc.pdf -o %CD%\scipydoc2.pdf concat ``` * -f参数指定输入的PDF文件名 * -o参数指定输出的PDF文件名,注意必须使用绝对路径,因此这里使用%CD%将相对路径转换为绝对路径。 ### 输出打包的批处理 下面是同时输出zip, chm, pdf文件的批处理命令: ``` rename html scipydoc "c:\Program Files\7-Zip\7z.exe" a scipydoc.zip scipydoc rename scipydoc html "C:\Program Files\HTML Help Workshop\hhc.exe" htmlhelpcn\scipydoc.hhp copy htmlhelpcn\scipydoc.chm . /y cd latex xelatex -interaction=nonstopmode scipydoc.tex cd .. copy latex\scipydoc.pdf . /y ``` ### HTML的中文搜索 由于Sphinx不懂中文分词,因此它所生成的搜索索引文件searchindex.js中的中文单词分的不正确。为了修正这个问题,我写了一个Sphinx扩展chinese_search.py,使用中文分词库smallseg生成索引文件中的中文单词。 smallseg中文分词库地址: [http://code.google.com/p/smallseg](http://code.google.com/p/smallseg) 下面是这个扩展的完整源程序: ``` from os import path import re import cPickle as pickle from docutils.nodes import comment, Text, NodeVisitor, SkipNode from sphinx.util.stemmer import PorterStemmer from sphinx.util import jsdump, rpartition from smallseg import SEG DEBUG = False word_re = re.compile(r'\w+(?u)') stopwords = set(""" a and are as at be but by for if in into is it near no not of on or such that the their then there these they this to was will with """.split()) if DEBUG: testfile = file("testfile.txt", "wb") class _JavaScriptIndex(object): """ The search index as javascript file that calls a function on the documentation search object to register the index. """ PREFIX = 'Search.setIndex(' SUFFIX = ')' def dumps(self, data): return self.PREFIX + jsdump.dumps(data) + self.SUFFIX def loads(self, s): data = s[len(self.PREFIX):-len(self.SUFFIX)] if not data or not s.startswith(self.PREFIX) or not \ s.endswith(self.SUFFIX): raise ValueError('invalid data') return jsdump.loads(data) def dump(self, data, f): f.write(self.dumps(data)) def load(self, f): return self.loads(f.read()) js_index = _JavaScriptIndex() class Stemmer(PorterStemmer): """ All those porter stemmer implementations look hideous. make at least the stem method nicer. """ def stem(self, word): word = word.lower() return word #return PorterStemmer.stem(self, word, 0, len(word) - 1) class WordCollector(NodeVisitor): """ A special visitor that collects words for the `IndexBuilder`. """ def __init__(self, document): NodeVisitor.__init__(self, document) self.found_words = [] def dispatch_visit(self, node): if node.__class__ is comment: raise SkipNode if node.__class__ is Text: words = seg.cut(node.astext().encode("utf8")) words.reverse() self.found_words.extend(words) class IndexBuilder(object): """ Helper class that creates a searchindex based on the doctrees passed to the `feed` method. """ formats = { 'jsdump': jsdump, 'pickle': pickle } def __init__(self, env): self.env = env self._stemmer = Stemmer() # filename -> title self._titles = {} # stemmed word -> set(filenames) self._mapping = {} # desctypes -> index self._desctypes = {} def load(self, stream, format): """Reconstruct from frozen data.""" if isinstance(format, basestring): format = self.formats[format] frozen = format.load(stream) # if an old index is present, we treat it as not existing. if not isinstance(frozen, dict): raise ValueError('old format') index2fn = frozen['filenames'] self._titles = dict(zip(index2fn, frozen['titles'])) self._mapping = {} for k, v in frozen['terms'].iteritems(): if isinstance(v, int): self._mapping[k] = set([index2fn[v]]) else: self._mapping[k] = set(index2fn[i] for i in v) # no need to load keywords/desctypes def dump(self, stream, format): """Dump the frozen index to a stream.""" if isinstance(format, basestring): format = self.formats[format] format.dump(self.freeze(), stream) def get_modules(self, fn2index): rv = {} for name, (doc, _, _, _) in self.env.modules.iteritems(): if doc in fn2index: rv[name] = fn2index[doc] return rv def get_descrefs(self, fn2index): rv = {} dt = self._desctypes for fullname, (doc, desctype) in self.env.descrefs.iteritems(): if doc not in fn2index: continue prefix, name = rpartition(fullname, '.') pdict = rv.setdefault(prefix, {}) try: i = dt[desctype] except KeyError: i = len(dt) dt[desctype] = i pdict[name] = (fn2index[doc], i) return rv def get_terms(self, fn2index): rv = {} for k, v in self._mapping.iteritems(): if len(v) == 1: fn, = v if fn in fn2index: rv[k] = fn2index[fn] else: rv[k] = [fn2index[fn] for fn in v if fn in fn2index] return rv def freeze(self): """Create a usable data structure for serializing.""" filenames = self._titles.keys() titles = self._titles.values() fn2index = dict((f, i) for (i, f) in enumerate(filenames)) return dict( filenames=filenames, titles=titles, terms=self.get_terms(fn2index), descrefs=self.get_descrefs(fn2index), modules=self.get_modules(fn2index), desctypes=dict((v, k) for (k, v) in self._desctypes.items()), ) def prune(self, filenames): """Remove data for all filenames not in the list.""" new_titles = {} for filename in filenames: if filename in self._titles: new_titles[filename] = self._titles[filename] self._titles = new_titles for wordnames in self._mapping.itervalues(): wordnames.intersection_update(filenames) def feed(self, filename, title, doctree): """Feed a doctree to the index.""" self._titles[filename] = title visitor = WordCollector(doctree) doctree.walk(visitor) def add_term(word, prefix='', stem=self._stemmer.stem): word = stem(word) word = word.strip(u"!@#$%^&*()_+-*/\\\";,.[]{}<>") if len(word) <= 1: return if word.encode("utf8").isalpha() and len(word) < 3: return if word.isdigit(): return if word in stopwords: return try: float(word) return except: pass if DEBUG: testfile.write("%s\n" % word.encode("utf8")) self._mapping.setdefault(prefix + word, set()).add(filename) words = seg.cut(title.encode("utf8")) for word in words: add_term(word) for word in visitor.found_words: add_term(word) def load_indexer(self): def func(docnames): print "############### CHINESE INDEXER ###############" self.indexer = IndexBuilder(self.env) keep = set(self.env.all_docs) - set(docnames) try: f = open(path.join(self.outdir, self.searchindex_filename), 'rb') try: self.indexer.load(f, self.indexer_format) finally: f.close() except (IOError, OSError, ValueError): if keep: self.warn('search index couldn\'t be loaded, but not all ' 'documents will be built: the index will be ' 'incomplete.') # delete all entries for files that will be rebuilt self.indexer.prune(keep) return func def builder_inited(app): if app.builder.name == 'html': print "****************************" global seg seg = SEG() app.builder.load_indexer = load_indexer(app.builder) def setup(app): app.connect('builder-inited', builder_inited) ``` ### PDF的页码和图编号参照 Sphinx生成的tex文件没有使用\label和\ref进行编号引用,而是生成一些链接,这些链接虽然方便电子版的阅读,可是打印出来之后就毫无用处了,因此我写了一个扩展latex_ref.py为最终生成的PDF添加编号引用功能,这个扩展添加了三个role:tlabel, tref, tpageref,分别对应tex的\label, \ref, \pageref。 下面是完整的源程序: ``` # -*- coding: utf-8 -*- from docutils import nodes, utils class tref(nodes.Inline, nodes.TextElement): pass class tlabel(nodes.Inline, nodes.TextElement): pass class tpageref(nodes.Inline, nodes.TextElement): pass def tref_role(role, rawtext, text, lineno, inliner, options={}, content=[]): data = text.split(",") if u"图" in data[0]: name = u"图" pos = data[0][0] ref = data[1] return [tref(name=name, ref=ref, pos=pos)], [] return [],[] def tlabel_role(role, rawtext, text, lineno, inliner, options={}, content=[]): return [tlabel(latex=text)], [] def tpageref_role(role, rawtext, text, lineno, inliner, options={}, content=[]): return [tpageref(latex=text)], [] def latex_visit_ref(self, node): self.body.append(r"%s\ref{%s}" % (node['name'], node['ref'])) raise nodes.SkipNode def html_visit_ref(self, node): self.body.append(r'<a href="#%s">%s%s</a>' % (node['ref'], node['pos'], node['name'])) raise nodes.SkipNode def latex_visit_label(self, node): self.body.append(r"\label{%s}" % node['latex']) raise nodes.SkipNode def latex_visit_pageref(self, node): self.body.append(r"\pageref{%s}" % node['latex']) raise nodes.SkipNode def empty_visit(self, node): raise nodes.SkipNode def setup(app): app.add_node(tref,latex=(latex_visit_ref, None),text=(empty_visit, None),html=(html_visit_ref, None)) app.add_node(tlabel,latex=(latex_visit_label, None),text=(empty_visit, None),html=(empty_visit, None)) app.add_node(tpageref,latex=(latex_visit_pageref, None),text=(empty_visit, None),html=(empty_visit, None)) app.add_role('tref', tref_role) app.add_role('tlabel', tlabel_role) app.add_role('tpageref', tpageref_role) ``` ## ReST使用心得 ### 添加图的编号和标题 使用figure命令插入带编号和标题的插图: ``` .. _pythonxyhome: .. figure:: images/pythonxy_home.png Python(x,y)的启动画面 ``` ### PDF文字包围图片 当给figure添加figwidth和align属性之后,在生成的latex文档中,将使用wrapfigure生成图。为了和前面的段落之间添加一个换行符,使用一个斜杠空格。 ``` .. literalinclude:: examples/tvtk_cone.example.py .. literalinclude:: example.c :language: c ``` ## 未解决的问题 **数学公式输出不正确** 有时候数学公式的输出不正确,某些数学符号不能显示,可是多试几次之后就正常了,不知道是什么原因。 **Leo不能配置目录树和编辑框的宽度比例** 每次Leo开启之后目录树和编辑框的宽度是相等的,看上去很不协调。而且修改mySettings.leo中的相关配置也不能解决,不明白是什么问题。目前的解决方法是添加两个工具按钮:show-tree和hide-tree,这样点击一下show-tree就会将目录树和编辑框改为1:3的比例;而点击hide-tree则能隐藏目录树: ``` # -*- coding: utf-8 -*- from enthought.traits.api import \ Str, Float, HasTraits, Property, cached_property, Range, Instance, on_trait_change, Enum from enthought.chaco.api import Plot, AbstractPlotData, ArrayPlotData, VPlotContainer from enthought.traits.ui.api import \ Item, View, VGroup, HSplit, ScrubberEditor, VSplit from enthought.enable.api import Component, ComponentEditor from enthought.chaco.tools.api import PanTool, ZoomTool import numpy as np # 鼠标拖动修改值的控件的样式 scrubber = ScrubberEditor( hover_color = 0xFFFFFF, active_color = 0xA0CD9E, border_color = 0x808080 ) # 取FFT计算的结果freqs中的前n项进行合成,返回合成结果,计算loops个周期的波形 def fft_combine(freqs, n, loops=1): length = len(freqs) * loops data = np.zeros(length) index = loops * np.arange(0, length, 1.0) / length * (2 * np.pi) for k, p in enumerate(freqs[:n]): if k != 0: p *= 2 # 除去直流成分之外,其余的系数都*2 data += np.real(p) * np.cos(k*index) # 余弦成分的系数为实数部 data -= np.imag(p) * np.sin(k*index) # 正弦成分的系数为负的虚数部 return index, data class TriangleWave(HasTraits): # 指定三角波的最窄和最宽范围,由于Range似乎不能将常数和traits名混用 # 所以定义这两个不变的trait属性 low = Float(0.02) hi = Float(1.0) # 三角波形的宽度 wave_width = Range("low", "hi", 0.5) # 三角波的顶点C的x轴坐标 length_c = Range("low", "wave_width", 0.5) # 三角波的定点的y轴坐标 height_c = Float(1.0) # FFT计算所使用的取样点数,这里用一个Enum类型的属性以供用户从列表中选择 fftsize = Enum( [(2**x) for x in range(6, 12)]) # FFT频谱图的x轴上限值 fft_graph_up_limit = Range(0, 400, 20) # 用于显示FFT的结果 peak_list = Str # 采用多少个频率合成三角波 N = Range(1, 40, 4) # 保存绘图数据的对象 plot_data = Instance(AbstractPlotData) # 绘制波形图的容器 plot_wave = Instance(Component) # 绘制FFT频谱图的容器 plot_fft = Instance(Component) # 包括两个绘图的容器 container = Instance(Component) # 设置用户界面的视图, 注意一定要指定窗口的大小,这样绘图容器才能正常初始化 view = View( HSplit( VSplit( VGroup( Item("wave_width", editor = scrubber, label=u"波形宽度"), Item("length_c", editor = scrubber, label=u"最高点x坐标"), Item("height_c", editor = scrubber, label=u"最高点y坐标"), Item("fft_graph_up_limit", editor = scrubber, label=u"频谱图范围"), Item("fftsize", label=u"FFT点数"), Item("N", label=u"合成波频率数") ), Item("peak_list", style="custom", show_label=False, width=100, height=250) ), VGroup( Item("container", editor=ComponentEditor(size=(600,300)), show_label = False), orientation = "vertical" ) ), resizable = True, width = 800, height = 600, title = u"三角波FFT演示" ) # 创建绘图的辅助函数,创建波形图和频谱图有很多类似的地方,因此单独用一个函数以 # 减少重复代码 def _create_plot(self, data, name, type="line"): p = Plot(self.plot_data) p.plot(data, name=name, title=name, type=type) p.tools.append(PanTool(p)) zoom = ZoomTool(component=p, tool_mode="box", always_on=False) p.overlays.append(zoom) p.title = name return p def __init__(self): # 首先需要调用父类的初始化函数 super(TriangleWave, self).__init__() # 创建绘图数据集,暂时没有数据因此都赋值为空,只是创建几个名字,以供Plot引用 self.plot_data = ArrayPlotData(x=[], y=[], f=[], p=[], x2=[], y2=[]) # 创建一个垂直排列的绘图容器,它将频谱图和波形图上下排列 self.container = VPlotContainer() # 创建波形图,波形图绘制两条曲线: 原始波形(x,y)和合成波形(x2,y2) self.plot_wave = self._create_plot(("x","y"), "Triangle Wave") self.plot_wave.plot(("x2","y2"), color="red") # 创建频谱图,使用数据集中的f和p self.plot_fft = self._create_plot(("f","p"), "FFT", type="scatter") # 将两个绘图容器添加到垂直容器中 self.container.add( self.plot_wave ) self.container.add( self.plot_fft ) # 设置 self.plot_wave.x_axis.title = "Samples" self.plot_fft.x_axis.title = "Frequency pins" self.plot_fft.y_axis.title = "(dB)" # 改变fftsize为1024,因为Enum的默认缺省值为枚举列表中的第一个值 self.fftsize = 1024 # FFT频谱图的x轴上限值的改变事件处理函数,将最新的值赋值给频谱图的响应属性 def _fft_graph_up_limit_changed(self): self.plot_fft.x_axis.mapper.range.high = self.fft_graph_up_limit def _N_changed(self): self.plot_sin_combine() # 多个trait属性的改变事件处理函数相同时,可以用@on_trait_change指定 @on_trait_change("wave_width, length_c, height_c, fftsize") def update_plot(self): # 计算三角波 global y_data x_data = np.arange(0, 1.0, 1.0/self.fftsize) func = self.triangle_func() # 将func函数的返回值强制转换成float64 y_data = np.cast["float64"](func(x_data)) # 计算频谱 fft_parameters = np.fft.fft(y_data) / len(y_data) # 计算各个频率的振幅 fft_data = np.clip(20*np.log10(np.abs(fft_parameters))[:self.fftsize/2+1], -120, 120) # 将计算的结果写进数据集 self.plot_data.set_data("x", np.arange(0, self.fftsize)) # x坐标为取样点 self.plot_data.set_data("y", y_data) self.plot_data.set_data("f", np.arange(0, len(fft_data))) # x坐标为频率编号 self.plot_data.set_data("p", fft_data) # 合成波的x坐标为取样点,显示2个周期 self.plot_data.set_data("x2", np.arange(0, 2*self.fftsize)) # 更新频谱图x轴上限 self._fft_graph_up_limit_changed() # 将振幅大于-80dB的频率输出 peak_index = (fft_data > -80) peak_value = fft_data[peak_index][:20] result = [] for f, v in zip(np.flatnonzero(peak_index), peak_value): result.append("%s : %s" %(f, v) ) self.peak_list = "\n".join(result) # 保存现在的fft计算结果,并计算正弦合成波 self.fft_parameters = fft_parameters self.plot_sin_combine() # 计算正弦合成波,计算2个周期 def plot_sin_combine(self): index, data = fft_combine(self.fft_parameters, self.N, 2) self.plot_data.set_data("y2", data) # 返回一个ufunc计算指定参数的三角波 def triangle_func(self): c = self.wave_width c0 = self.length_c hc = self.height_c def trifunc(x): x = x - int(x) # 三角波的周期为1,因此只取x坐标的小数部分进行计算 if x >= c: r = 0.0 elif x < c0: r = x / c0 * hc else: r = (c-x) / (c-c0) * hc return r # 用trifunc函数创建一个ufunc函数,可以直接对数组进行计算, 不过通过此函数 # 计算得到的是一个Object数组,需要进行类型转换 return np.frompyfunc(trifunc, 1, 1) if __name__ == "__main__": triangle = TriangleWave() triangle.configure_traits() ```
';

分形与混沌

最后更新于:2022-04-01 11:15:46

# 分形与混沌 自然界的很多事物,例如树木、云彩、山脉、闪电、雪花以及海岸线等等都呈现出传统的几何学不能描述的形状。这些形状都有如下的特性: * 有着十分精细的不规则的结构 * 整体与局部相似,例如一根树杈的形状和一棵树很像 分形几何学就是用来研究这样一类的几何形状的科学,借助计算机的高速计算和图像显示,使得我们可以更加深入地直观地观察分形几何。在本章中,让我们用Python绘制一些经典的分形图案。 ## Mandelbrot集合 Mandelbrot(曼德布洛特)集合是在复平面上组成分形的点的集合。 Mandelbrot集合的定义(摘自维基百科) Mandelbrot集合可以用下面的复二次多项式定义: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bba6cbc8.png) 其中c是一个复参数。对于每一个c,从z=0开始对函数 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bba82d58.png) 进行迭代。 序列 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bba94938.png) 的值或者延伸到无限大,或者只停留在有限半径的圆盘内。 Mandelbrot集合就是使以上序列不发散的所有c点的集合。 从数学上来讲,Mandelbrot集合是一个复数的集合。一个给定的复数c或者属于Mandelbrot集合,或者不是。 用程序绘制Mandelbrot集合时不能进行无限次迭代,最简单的方法是使用逃逸时间(迭代次数)进行绘制,具体算法如下: * 判断每次调用函数 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bba82d58.png) 得到的结果是否在半径R之内,即复数的模小于R * 记录下模大于R时的迭代次数 * 迭代最多进行N次 * 不同的迭代次数的点使用不同的颜色绘制 下面是完整的绘制Mandelbrot集合的程序: ``` # -*- coding: utf-8 -*- import numpy as np import pylab as pl import time from matplotlib import cm def iter_point(c): z = c for i in xrange(1, 100): # 最多迭代100次 if abs(z)>2: break # 半径大于2则认为逃逸 z = z*z+c return i # 返回迭代次数 def draw_mandelbrot(cx, cy, d): """ 绘制点(cx, cy)附近正负d的范围的Mandelbrot """ x0, x1, y0, y1 = cx-d, cx+d, cy-d, cy+d y, x = np.ogrid[y0:y1:200j, x0:x1:200j] c = x + y*1j start = time.clock() mandelbrot = np.frompyfunc(iter_point,1,1)(c).astype(np.float) print "time=",time.clock() - start pl.imshow(mandelbrot, cmap=cm.Blues_r, extent=[x0,x1,y0,y1]) pl.gca().set_axis_off() x,y = 0.27322626, 0.595153338 pl.subplot(231) draw_mandelbrot(-0.5,0,1.5) for i in range(2,7): pl.subplot(230+i) draw_mandelbrot(x, y, 0.2**(i-1)) pl.subplots_adjust(0.02, 0, 0.98, 1, 0.02, 0) pl.show() ``` ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bbaae9d4.png) Mandelbrot集合,以5倍的倍率放大点(0.273, 0.595)附近 程序中的iter_point函数计算点c的逃逸时间,逃逸半径R为2.0,最大迭代次数为100。draw_mandelbrot函数绘制以点(cx, cy)为中心,边长为2*d的正方形区域内的Mandelbrot集合。 下面3行计算指定范围内的迭代公式的参数c,c是一个元素为复数的二维数组,大小为200*200,注意np.ogrid不是函数: ``` x0, x1, y0, y1 = cx-d, cx+d, cy-d, cy+d y, x = np.ogrid[y0:y1:200j, x0:x1:200j] c = x + y*1j ``` 下面一行程序通过调用np.frompyfunc将iter_point转换为NumPy的ufunc函数,这样它可以自动对c中的每个元素调用iter_point函数,由于结果的数组元素类型为object,还需要调用astype方法将其元素类型转换为浮点类型: ``` mandelbrot = np.frompyfunc(iter_point,1,1)(c).astype(np.float) ``` 最后调用matplotlib的imshow函数将结果数组绘制成图,通过cmap关键字参数指定图的值和颜色的映射表: ``` pl.imshow(mandelbrot, cmap=cm.Blues_r, extent=[x0,x1,y0,y1]) ``` 使用Python绘制Mandelbrot集合最大的问题就是运算速度太慢,下面是上面每幅图的计算时间: ``` time= 0.88162629608 time= 1.53712748408 time= 1.71502160191 time= 1.8691174437 time= 3.03812691278 ``` 因为计算每个点的逃逸时间均不相同,因此每幅图的计算时间也不相同。 计算速度慢的最大的原因是因为iter_point函数的运算速度慢,如果将此函数用C语言重写的话将能显著地提高计算速度,下面使用scipy.weave库将C++重写的iter_point函数转换为Python能调用的函数: ``` import scipy.weave as weave def weave_iter_point(c): code = """ std::complex<double> z; int i; z = c; for(i=1;i<100;i++) { if(std::abs(z) > 2) break; z = z*z+c; } return_val=i; """ f = weave.inline(code, ["c"], compiler="gcc") return f ``` 下面是使用weave_iter_point函数计算Mandelbrot集合的时间: ``` time= 0.285266982256 time= 0.271430028118 time= 0.293769180161 time= 0.308515188383 time= 0.411168179196 ``` 通过NumPy的数组运算也可以提高计算速度,前面的计算都是先对复数平面上的每个点进行循环,然后再循环迭代计算每个点的逃逸时间。如果要用NumPy的数组运算加速计算的话,可以将这两个循环的顺序颠倒过来,下面的程序演示这一算法: ``` # -*- coding: utf-8 -*- import numpy as np import pylab as pl import time from matplotlib import cm def draw_mandelbrot(cx, cy, d, N=200): """ 绘制点(cx, cy)附近正负d的范围的Mandelbrot """ global mandelbrot x0, x1, y0, y1 = cx-d, cx+d, cy-d, cy+d y, x = np.ogrid[y0:y1:N*1j, x0:x1:N*1j] c = x + y*1j # 创建X,Y轴的坐标数组 ix, iy = np.mgrid[0:N,0:N] # 创建保存mandelbrot图的二维数组,缺省值为最大迭代次数 mandelbrot = np.ones(c.shape, dtype=np.int)*100 # 将数组都变成一维的 ix.shape = -1 iy.shape = -1 c.shape = -1 z = c.copy() # 从c开始迭代,因此开始的迭代次数为1 start = time.clock() for i in xrange(1,100): # 进行一次迭代 z *= z z += c # 找到所有结果逃逸了的点 tmp = np.abs(z) > 2.0 # 将这些逃逸点的迭代次数赋值给mandelbrot图 mandelbrot[ix[tmp], iy[tmp]] = i # 找到所有没有逃逸的点 np.logical_not(tmp, tmp) # 更新ix, iy, c, z只包含没有逃逸的点 ix,iy,c,z = ix[tmp], iy[tmp], c[tmp],z[tmp] if len(z) == 0: break print "time=",time.clock() - start pl.imshow(mandelbrot, cmap=cm.Blues_r, extent=[x0,x1,y0,y1]) pl.gca().set_axis_off() x,y = 0.27322626, 0.595153338 pl.subplot(231) draw_mandelbrot(-0.5,0,1.5) for i in range(2,7): pl.subplot(230+i) draw_mandelbrot(x, y, 0.2**(i-1)) pl.subplots_adjust(0.02, 0, 0.98, 1, 0.02, 0) pl.show() ``` 为了减少计算次数,程序中每次迭代之后,都将已经逃逸的点剔除出去,这样就需要保存每个点的下标,程序中用ix和iy这两个数组来保存没有逃逸的点的下标,因为有额外的数组保存下标,因此数组z和c不需要是二维的。函数迭代部分的程序如下: ``` # 进行一次迭代 z *= z z += c ``` 使用 [*](#id2)=, += 这样的运算符能够让NumPy不分配额外的空间直接在数组z上进行运算。 下面的程序计算出逃逸点,tmp是逃逸点在z中的下标,由于z和ix和iy等数组始终是同时更新的,因此ix[tmp], iy[tmp]就是逃逸点在图像中的下标: ``` # 找到所有结果逃逸了的点 tmp = np.abs(z) > 2.0 # 将这些逃逸点的迭代次数赋值给mandelbrot图 mandelbrot[ix[tmp], iy[tmp]] = i ``` 最后通过对tmp中的每个元素取逻辑反,更新所有没有逃逸的点的对应的ix, iy, c, z: ``` # 找到所有没有逃逸的点 np.logical_not(tmp, tmp) # 更新ix, iy, c, z只包含没有逃逸的点 ix,iy,c,z = ix[tmp], iy[tmp], c[tmp], z[tmp] ``` 此程序的计算时间如下: ``` time= 0.186070576008 time= 0.327006365334 time= 0.372756034636 time= 0.410074464771 time= 0.681048289658 time= 0.878626752841 ``` ### 连续的逃逸时间 修改逃逸半径R和最大迭代次数N,可以绘制出不同效果的Mandelbrot集合图案。但是前面所述的方法计算出的逃逸时间是大于逃逸半径时的迭代次数,因此所输出的图像最多只有N种不同的颜色值,有很强的梯度感。为了在不同的梯度之间进行渐变处理,使用下面的公式进行逃逸时间计算: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bbad8774.png) ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bbae98da.png) 是迭代n次之后的结果,通过在逃逸时间的计算中引入迭代结果的模值,结果将不再是整数,而是平滑渐变的。 下面是计算此逃逸时间的程序: ``` def smooth_iter_point(c): z = c for i in xrange(1, iter_num): if abs(z)>escape_radius: break z = z*z+c absz = abs(z) if absz > 2.0: mu = i - log(log(abs(z),2),2) else: mu = i return mu # 返回正规化的迭代次数 ``` 如果你的逃逸半径设置得很小,例如2.0,那么有可能结果不够平滑,这时可以在迭代循环之后添加几次迭代保证z能够足够逃逸,例如: ``` z = z*z+c z = z*z+c i += 2 ``` 下图是逃逸半径为10,最大迭代次数为20时,绘制的结果: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bbb05dd5.png) 逃逸半径=10,最大迭代次数=20的平滑处理后的Mandelbrot集合 逃逸时间公式是如何得出的? 请参考: [http://linas.org/art-gallery/escape/ray.html](http://linas.org/art-gallery/escape/ray.html) 完整的程序请参考 [_绘制Mandelbrot集合_](example_mandelbrot.html) ## 迭代函数系统(IFS) 迭代函数系统是一种用来创建分形图案的算法,它所创建的分形图永远是绝对自相似的。下面我们直接通过绘制一种蕨类植物的叶子来说明迭代函数系统的算法: 有下面4个线性函数将二维平面上的坐标进行线性映射变换: ``` 1. x(n+1)= 0 y(n+1) = 0.16 * y(n) 2. x(n+1) = 0.2 * x(n) − 0.26 * y(n) y(n+1) = 0.23 * x(n) + 0.22 * y(n) + 1.6 3. x(n+1) = −0.15 * x(n) + 0.28 * y(n) y(n+1) = 0.26 * x(n) + 0.24 * y(n) + 0.44 4. x(n+1) = 0.85 * x(n) + 0.04 * y(n) y(n+1) = −0.04 * x(n) + 0.85 * y(n) + 1.6 ``` 所谓迭代函数是指将函数的输出再次当作输入进行迭代计算,因此上面公式都是通过坐标 x(n),y(n) 计算变换后的坐标 x(n+1),y(n+1)。现在的问题是有4个迭代函数,迭代时选择哪个函数进行计算呢?我们为每个函数指定一个概率值,它们依次为1%, 7%, 7%和85%。选择迭代函数时使用通过每个函数的概率随机选择一个函数进行迭代。上面的例子中,第四个函数被选择迭代的概率最高。 最后我们从坐标原点(0,0)开始迭代,将每次迭代所得到的坐标绘制成图,就得到了叶子的分形图案。下面的程序演示这一计算过程: ``` # -*- coding: utf-8 -*- import numpy as np import matplotlib.pyplot as pl import time # 蕨类植物叶子的迭代函数和其概率值 eq1 = np.array([[0,0,0],[0,0.16,0]]) p1 = 0.01 eq2 = np.array([[0.2,-0.26,0],[0.23,0.22,1.6]]) p2 = 0.07 eq3 = np.array([[-0.15, 0.28, 0],[0.26,0.24,0.44]]) p3 = 0.07 eq4 = np.array([[0.85, 0.04, 0],[-0.04, 0.85, 1.6]]) p4 = 0.85 def ifs(p, eq, init, n): """ 进行函数迭代 p: 每个函数的选择概率列表 eq: 迭代函数列表 init: 迭代初始点 n: 迭代次数 返回值: 每次迭代所得的X坐标数组, Y坐标数组, 计算所用的函数下标 """ # 迭代向量的初始化 pos = np.ones(3, dtype=np.float) pos[:2] = init # 通过函数概率,计算函数的选择序列 p = np.add.accumulate(p) rands = np.random.rand(n) select = np.ones(n, dtype=np.int)*(n-1) for i, x in enumerate(p[::-1]): select[rands<x] = len(p)-i-1 # 结果的初始化 result = np.zeros((n,2), dtype=np.float) c = np.zeros(n, dtype=np.float) for i in xrange(n): eqidx = select[i] # 所选的函数下标 tmp = np.dot(eq[eqidx], pos) # 进行迭代 pos[:2] = tmp # 更新迭代向量 # 保存结果 result[i] = tmp c[i] = eqidx return result[:,0], result[:, 1], c start = time.clock() x, y, c = ifs([p1,p2,p3,p4],[eq1,eq2,eq3,eq4], [0,0], 100000) print time.clock() - start pl.figure(figsize=(6,6)) pl.subplot(121) pl.scatter(x, y, s=1, c="g", marker="s", linewidths=0) pl.axis("equal") pl.axis("off") pl.subplot(122) pl.scatter(x, y, s=1,c = c, marker="s", linewidths=0) pl.axis("equal") pl.axis("off") pl.subplots_adjust(left=0,right=1,bottom=0,top=1,wspace=0,hspace=0) pl.gcf().patch.set_facecolor("white") pl.show() ``` 程序中的ifs函数是进行函数迭代的主函数,我们希望通过矩阵乘法计算函数(numpy.dot)的输出,因此需要将乘法向量扩充为三维的: ``` pos = np.ones(3, dtype=np.float) pos[:2] = init ``` 这样每次和迭代函数系数进行矩阵乘积运算的向量就变成了: x(n), y(n), 1.0 。 为了减少计算时间,我们不在迭代循环中计算随机数选择迭代方程,而是事先通过每个函数的概率,计算出函数选择数组select,注意这里使用accumulate函数先将概率累加,然后产生一组0到1之间的随机数,通过判断随机数所在的概率区间选择不同的方程下标: ``` p = np.add.accumulate(p) rands = np.random.rand(n) select = np.ones(n, dtype=np.int)*(n-1) for i, x in enumerate(p[::-1]): select[rands<x] = len(p)-i-1 ``` 最后我们通过调用scatter绘图函数将所得到的坐标进行散列图绘制: ``` pl.scatter(x, y, s=1, c="g", marker="s", linewidths=0) ``` 其中每个关键字参数的含义如下: * **s** : 散列点的大小,因为我们要绘制10万点,因此大小选择为1 * **c** : 点的颜色,这里选择绿色 * **marker** : 点的形状,"s"表示正方形,方形的绘制是最快的 * **linewidths** : 点的边框宽度,0表示没有边框 此外,关键字参数c还可以传入一个数组,作为每个点的颜色值,我们将计算坐标的函数下标传入,这样可以直观地看出哪个点是哪个函数迭代产生的: ``` pl.scatter(x, y, s=1,c = c, marker="s", linewidths=0) ``` 下图是程序的输出: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bbb31a12.png) 函数迭代系统所绘制的蕨类植物的叶子 观察右图的4种颜色的部分可以发现概率为1%的函数1所计算的是叶杆部分(深蓝色),概率为7%的两个函数计算的是左右两片子叶,而概率为85%的函数计算的是整个叶子的迭代:即最下面的三种颜色的点通过此函数的迭代产生上面的所有的深红色的点。 我们可以看出整个叶子呈现出完美的自相似特性,任意取其中的一个子叶,将其旋转放大之后都和整个叶子相同。 ### 2D仿射变换 上面所介绍的4个变换方程的一般形式如下: ``` x(n+1) = A * x(n) + B * y(n) + C y(n+1) = D * x(n) + E * y(n) + F ``` 这种变换被称为2D仿射变换,它是从2D坐标到其他2D坐标的线性映射,保留直线性和平行性。即原来是直线上的坐标,变换之后仍然成一条直线,原来是平行的直线,变换之后仍然是平行的。这种变换我们可以看作是一系列平移、缩放、翻转和旋转变换构成的。 为了直观地显示仿射变换,我们可以使用平面上的两个三角形来表示。因为仿射变换公式中有6个未知数:A, B, C, D, E, F,而每两个点之间的变换决定两个方程,因此一共需要3组点来决定六个变换方程,正好是两个三角形,如下图所示: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bbb58fd0.png) 两个三角形决定一个2D仿射变换的六个参数 从红色三角形的每个顶点变换到绿色三角形的对应顶点,正好能够决定仿射变换中的六个参数。这样我们可是使用N+1个三角形,决定N个仿射变换,其中的每一个变换的参数都是由第0个三角形和其它的三角形决定的。这第0个三角形我们称之为基础三角形,其余的三角形称之为变换三角形。 为了绘制迭代函数系统的图像,我们还需要给每个仿射变换方程指定一个迭代概率的参数。此参数也可以使用三角形直观地表达出来:迭代概率和变换三角形的面积成正比。即迭代概率为变换三角形的面积除以所有变换三角形的面积之和。 如下图所示,前面介绍的蕨类植物的分形图案的迭代方程可以由5个三角形决定,可以很直观地看出紫色的小三角形决定了叶子的茎;而两个蓝色的三角形决定了左右两片子叶;绿色的三角形将茎和两片子叶往上复制,形成整片叶子。 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bbb68a98.png) 5个三角形的仿射方程绘制蕨类植物的叶子 ### 迭代函数系统设计器 按照上节所介绍的三角形法,我们可以设计一个迭代函数系统的设计工具,如下图所示: <object classid="clsid:D27CDB6E-AE6D-11cf-96B8-444553540000" width="600" height="370" codebase="http://active.macromedia.com/flash5/cabs/swflash.cab#version=7,0,0,0"><param name="movie" value="img/ifs.swf"> <param name="play" value="true"> <param name="loop" value="false"> <param name="wmode" value="transparent"> <param name="quality" value="high"> <embed src="img/ifs.swf" width="600" height="370" quality="high" loop="false" wmode="transparent" type="application/x-shockwave-flash" pluginspage="http://www.macromedia.com/shockwave/download/index.cgi?P1_Prod_Version=ShockwaveFlash"> </object> ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bbb83791.swf) 具体的程序请参照 [_迭代函数系统的分形_](example_ifs.html) ,这里简单地介绍一下程序的几个核心组成部分: 首先通过两个三角形求解仿射方程的系数相当于求六元线性方程组的解,solve_eq函数完成这一工作,它先计算出线性方程组的矩阵a和b, 然后调用NumPy的linalg.solve对线性方程组 a*X = b 求解: ``` def solve_eq(triangle1, triangle2): """ 解方程,从triangle1变换到triangle2的变换系数 triangle1,2是二维数组: x0,y0 x1,y1 x2,y2 """ x0,y0 = triangle1[0] x1,y1 = triangle1[1] x2,y2 = triangle1[2] a = np.zeros((6,6), dtype=np.float) b = triangle2.reshape(-1) a[0, 0:3] = x0,y0,1 a[1, 3:6] = x0,y0,1 a[2, 0:3] = x1,y1,1 a[3, 3:6] = x1,y1,1 a[4, 0:3] = x2,y2,1 a[5, 3:6] = x2,y2,1 c = np.linalg.solve(a, b) c.shape = (2,3) return c ``` triangle_area函数计算三角形的面积,它使用NumPy的cross函数计算三角形的两个边的向量的叉积: ``` def triangle_area(triangle): """ 计算三角形的面积 """ A = triangle[0] B = triangle[1] C = triangle[2] AB = A-B AC = A-C return np.abs(np.cross(AB,AC))/2.0 ``` 整个程序的界面使用TraitsUI库生成,将matplotlib的Figure控件通过MPLFigureEditor和_MPLFigureEditor类嵌入到TraitsUI生成的界面中,请参考: [_设计自己的Trait编辑器_](traitsui_manual_custom_editor.html) IFSDesigner._figure_default创建Figure对象,并且添加两个并排的子图ax和ax2,ax用于三角形编辑,而ax2用于分形图案显示。 ``` def _figure_default(self): """ figure属性的缺省值,直接创建一个Figure对象 """ figure = Figure() self.ax = figure.add_subplot(121) self.ax2 = figure.add_subplot(122) self.ax2.set_axis_off() self.ax.set_axis_off() figure.subplots_adjust(left=0,right=1,bottom=0,top=1,wspace=0,hspace=0) figure.patch.set_facecolor("w") return figure ``` IFSTriangles类完成三角形的编辑工作,其中通过如下的语句绑定Figure控件的canvas的鼠标事件 ``` canvas = ax.figure.canvas # 绑定canvas的鼠标事件 canvas.mpl_connect('button_press_event', self.button_press_callback) canvas.mpl_connect('button_release_event', self.button_release_callback) canvas.mpl_connect('motion_notify_event', self.motion_notify_callback) ``` 由于canvas只有在真正显示Figure时才会创建,因此不能在创建Figure控件时创建IFSTriangles对象,而需要在界面生成之后,显示之前创建它。这里我们通过给IFSDesigner类的view属性指定其handler为IFSHandler对象,重载Handler的init方法,此方法在界面生成之后,显示之前被调用: ``` class IFSHandler(Handler): """ 在界面显示之前需要初始化的内容 """ def init(self, info): info.object.init_gui_component() return True ``` 然后IFSDesigner类的init_gui_component方法完成实际和canvas相关的初始工作: ``` def init_gui_component(self): self.ifs_triangle = IFSTriangles(self.ax) self.figure.canvas.draw() thread.start_new_thread( self.ifs_calculate, ()) ... ``` 由于通过函数迭代计算分形图案比较费时,因此在另外一个线程中执行ifs_calculate方法进行运算,每计算ITER_COUNT个点,就调用ax2.scatter将产生的点添加进ax2中,由于随着ax2中的点数增加,界面重绘将越来越慢,因此在draw_points函数中限制最多只调用ITER_TIMES次scatter函数。因为在别的线程中不能更新界面,因此通过调用wx.CallAfter在管理GUI的线程中调用draw_points进行界面刷新。: ``` def ifs_calculate(self): """ 在别的线程中计算 """ def draw_points(x, y, c): if len(self.ax2.collections) < ITER_TIMES: try: self.ax2.scatter(x, y, s=1, c=c, marker="s", linewidths=0) self.ax2.set_axis_off() self.ax2.axis("equal") self.figure.canvas.draw() except: pass def clear_points(): self.ax2.clear() while 1: try: if self.exit == True: break if self.clear == True: self.clear = False self.initpos = [0, 0] x, y, c = ifs( self.ifs_triangle.get_areas(), self.ifs_triangle.get_eqs(), self.initpos, 100) self.initpos = [x[-1], y[-1]] self.ax2.clear() x, y, c = ifs( self.ifs_triangle.get_areas(), self.ifs_triangle.get_eqs(), self.initpos, ITER_COUNT) if np.max(np.abs(x)) < 1000000 and np.max(np.abs(y)) < 1000000: self.initpos = [x[-1], y[-1]] wx.CallAfter( draw_points, x, y, c ) time.sleep(0.05) except: pass ``` 用户修改三角形之后,需要重新迭代,并绘制分形图案,三角形的改变通过 IFSTriangles.version 属性通知给IFSDesigner,在IFSTriangles中,三角形改变之后,将运行: ``` self.version += 1 ``` 在IFSDesigner中监听version属性的变化: ``` @on_trait_change("ifs_triangle.version") def on_ifs_version_changed(self): """ 当三角形更新时,重新绘制所有的迭代点 """ self.clear = True ``` 当IFSDesigner.clear为True时,真正进行迭代运算的ifs_calculate方法就知道需要重新计算了。 ## L-System分形 前面所绘制的分形图案都是都是使用数学函数的迭代产生,而L-System分形则是采用符号的递归迭代产生。首先如下定义几个有含义的符号: * **F** : 向前走固定单位 * **+** : 正方向旋转固定单位 * **-** : 负方向旋转固定单位 使用这三个符号我们很容易描述下图中由4条线段构成的图案: ``` F+F--F+F ``` 如果将此符号串中的所有F都替换为F+F--F+F,就能得到如下的新字符串: ``` F+F--F+F+F+F--F+F--F+F--F+F+F+F--F+F ``` 如此替换迭代下去,并根据字串进行绘图(符号+和-分别正负旋转60度),可得到如下的分形图案: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bbbc9988.png) 使用F+F--F+F迭代的分形图案 除了 F, +, - 之外我们再定义如下几个符号: * **f** : 向前走固定单位,为了定义不同的迭代公式 * **[** : 将当前的位置入堆栈 * **]** : 从堆栈中读取坐标,修改当前位置 * **S** : 初始迭代符号 所有的符号(包括上面未定义的)都可以用来定义迭代,通过引入两个方括号符号,使得我们能够描述分岔的图案。 例如下面的符号迭代能够绘制出一棵植物: ``` S -> X X -> F-[[X]+X]+F[+FX]-X F -> FF ``` 我们用一个字典定义所有的迭代公式和其它的一些绘图信息: ``` { "X":"F-[[X]+X]+F[+FX]-X", "F":"FF", "S":"X", "direct":-45, "angle":25, "iter":6, "title":"Plant" } ``` 其中: * **direct** : 是绘图的初始角度,通过指定不同的值可以旋转整个图案 * **angle** : 定义符号+,-旋转时的角度,不同的值能产生完全不同的图案 * **iter** : 迭代次数 下面的程序将上述字典转换为需要绘制的线段坐标: ``` class L_System(object): def __init__(self, rule): info = rule['S'] for i in range(rule['iter']): ninfo = [] for c in info: if c in rule: ninfo.append(rule[c]) else: ninfo.append(c) info = "".join(ninfo) self.rule = rule self.info = info def get_lines(self): d = self.rule['direct'] a = self.rule['angle'] p = (0.0, 0.0) l = 1.0 lines = [] stack = [] for c in self.info: if c in "Ff": r = d * pi / 180 t = p[0] + l*cos(r), p[1] + l*sin(r) lines.append(((p[0], p[1]), (t[0], t[1]))) p = t elif c == "+": d += a elif c == "-": d -= a elif c == "[": stack.append((p,d)) elif c == "]": p, d = stack[-1] del stack[-1] return lines ``` 我们使用matplotlib的LineCollection绘制所有的直线: ``` import matplotlib.pyplot as pl from matplotlib import collections # rule = {...} 此处省略rule的定义 lines = L_System(rule).get_lines() fig = pl.figure() ax = fig.add_subplot(111) linecollections = collections.LineCollection(lines) ax.add_collection(linecollections, autolim=True) pl.show() ``` 下面是几种L-System的分形图案,绘制此图的完整程序请参照 [_绘制L-System的分形图_](example_lsystem.html) 。 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bbbe10d4.png) 几种L-System的迭代图案
';

单摆和双摆模拟

最后更新于:2022-04-01 11:15:43

# 单摆和双摆模拟 ## 单摆模拟 由一根不可伸长、质量不计的绳子,上端固定,下端系一质点,这样的装置叫做单摆。 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb7ba346.png) 单摆装置示意图 根据牛顿力学定律,我们可以列出如下微分方程: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb7c8d0b.png) 其中 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb7db44c.png) 为单摆的摆角, ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb7e894c.png) 为单摆的长度, g为重力加速度。 此微分方程的符号解无法直接求出,因此只能调用odeint对其求数值解。 odeint函数的调用参数如下: ``` odeint(func, y0, t, ...) ``` 其中func是一个Python的函数对象,用来计算微分方程组中每个未知函数的导数,y0为微分方程组中每个未知函数的初始值,t为需要进行数值求解的时间点。它返回的是一个二维数组result,其第0轴的长度为t的长度,第1轴的长度为变量的个数,因此 result[:, i] 为第i个未知函数的解。 计算微分的func函数的调用参数为: func(y, t),其中y是一个数组,为每个未知函数在t时刻的值,而func的返回值也是数组,它为每个未知函数在t时刻的导数。 odeint要求每个微分方程只包含一阶导数,因此我们需要对前面的微分方程做如下的变形: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb8046cd.png) ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb81648c.png) 下面是利用odeint计算单摆轨迹的程序: ``` # -*- coding: utf-8 -*- from math import sin import numpy as np from scipy.integrate import odeint g = 9.8 def pendulum_equations(w, t, l): th, v = w dth = v dv = - g/l * sin(th) return dth, dv if __name__ == "__main__": import pylab as pl t = np.arange(0, 10, 0.01) track = odeint(pendulum_equations, (1.0, 0), t, args=(1.0,)) pl.plot(t, track[:, 0]) pl.title(u"单摆的角度变化, 初始角度=1.0弧度") pl.xlabel(u"时间(秒)") pl.ylabel(u"震度角度(弧度)") pl.show() ``` odeint函数还有一个关键字参数args,其值为一个组元,这些值都会作为额外的参数传递给func函数。程序使用这种方式将单摆的长度传递给pendulum_equations函数。 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb82458d.png) 初始角度为1弧度的单摆摆动角度和时间的关系 ### 计算摆动周期 高中物理课介绍过当最大摆动角度很小时,单摆的摆动周期可以使用如下公式计算: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb83791f.png) 这是因为当 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb847104.png) 时, ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb85f9ed.png) , 这样微分方程就变成了: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb86fcec.png) 此微分方程的解是一个简谐震动方程,很容易计算其摆动周期。但是当初始摆角增大时,上述的近似处理会带来无法忽视的误差。下面让我们来看看如何用数值计算的方法求出单摆在任意初始摆角时的摆动周期。 要计算摆动周期只需要计算从最大摆角到0摆角所需的时间,摆动周期是此时间的4倍。为了计算出这个时间值,首先需要定义一个函数pendulum_th计算任意时刻的摆角: ``` def pendulum_th(t, l, th0): track = odeint(pendulum_equations, (th0, 0), [0, t], args=(l,)) return track[-1, 0] ``` pendulum_th函数计算长度为l初始角度为th0的单摆在时刻t的摆角。此函数仍然使用odeint进行微分方程组求解,只是我们只需要计算时刻t的摆角,因此传递给odeint的时间序列为[0, t]。 odeint内部会对时间进行细分,保证最终的解是正确的。 接下来只需要找到第一个时pendulum_th的结果为0的时间即可。这相当于对pendulum_th函数求解,可以使用 scipy.optimize.fsolve 函数对这种非线性方程进行求解。 ``` def pendulum_period(l, th0): t0 = 2*np.pi*sqrt( l/g ) / 4 t = fsolve( pendulum_th, t0, args = (l, th0) ) return t*4 ``` 和odeint一样,我们通过fsolve的args关键字参数将额外的参数传递给pendulum_th函数。fsolve求解时需要一个初始值尽量接近真实的解,用小角度单摆的周期的1/4作为这个初始值是一个很不错的选择。下面利用pendulum_period函数计算出初始摆动角度从0到90度的摆动周期: ``` ths = np.arange(0, np.pi/2.0, 0.01) periods = [pendulum_period(1, th) for th in ths] ``` 为了验证fsolve求解摆动周期的正确性,我从维基百科中找到摆动周期的精确解: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb87ea76.png) 其中的函数K为第一类完全椭圆积分函数,其定义如下: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb88dc50.png) 我们可以用 scipy.special.ellipk 来计算此函数的值: ``` periods2 = 4*sqrt(1.0/g)*ellipk(np.sin(ths/2)**2) ``` 下图比较两种计算方法,我们看到其结果是完全一致的: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb89b52f.png) 单摆的摆动周期和初始角度的关系 完整的程序请参见: [_单摆摆动周期的计算_](example_simple_pendulum_period.html) ## 双摆模拟 接下来让我们来看看如何对双摆系统进行模拟。双摆系统的如下图所示, ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb8ae047.png) 双摆装置示意图 两根长度为L1和L2的无质量的细棒的顶端有质量分别为m1和m2的两个球,初始角度为 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb8bd43b.png) 和 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb8cbe8a.png) , 要求计算从此初始状态释放之后的两个球的运动轨迹。 ### 公式推导 本节首先介绍如何利用拉格朗日力学获得双摆系统的微分方程组。 拉格朗日力学(摘自维基百科) 拉格朗日力学是分析力学中的一种。於 1788 年由拉格朗日所创立,拉格朗日力学是对经典力学的一种的新的理论表述。 经典力学最初的表述形式由牛顿建立,它着重於分析位移,速度,加速度,力等矢量间的关系,又称为矢量力学。拉格朗日引入了广义坐标的概念,又运用达朗贝尔原理,求得与牛顿第二定律等价的拉格朗日方程。不仅如此,拉格朗日方程具有更普遍的意义,适用范围更广泛。还有,选取恰当的广义坐标,可以大大地简化拉格朗日方程的求解过程。 假设杆L1连接的球体的坐标为x1和y1,杆L2连接的球体的坐标为x2和y2,那么x1,y1,x2,y2和两个角度之间有如下关系: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb8da6b7.png) ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb8e96de.png) ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb903360.png) ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb9138c1.png) 根据拉格朗日量的公式: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb92399a.png) 其中T为系统的动能,V为系统的势能,可以得到如下公式: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb93188f.png) 其中正号的项目为两个小球的动能,符号的项目为两个小球的势能。 将前面的坐标和角度之间的关系公式带入之后整理可得: ![\mathcal{L} = \frac{m_1+m_2}{2} L_1^2 {\dot \theta_1}^2 + \frac{m_2}{2} L_2^2 {\dot \theta_2}^2 + m_2 L_1 L_2 {\dot \theta_1} {\dot \theta_2} \cos(\theta_1 - \theta_2) + (m_1 + m_2) g L_1 \cos(\theta_1) + m_2 g L_2 \cos(\theta_2)](img/ddb810d6eed716ca3b15c0a3e9c94185accee394.png) 对于变量 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb8bd43b.png) 的拉格朗日方程: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb945ade.png) 得到: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb954dd3.png) 对于变量 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb8cbe8a.png) 的拉格朗日方程: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb968801.png) 得到: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb975ad8.png) 这一计算过程可以用sympy进行推导: ``` # -*- coding: utf-8 -*- from sympy import * from sympy import Derivative as D var("x1 x2 y1 y2 l1 l2 m1 m2 th1 th2 dth1 dth2 ddth1 ddth2 t g tmp") sublist = [ (D(th1(t), t, t), ddth1), (D(th1(t), t), dth1), (D(th2(t), t, t), ddth2), (D(th2(t),t), dth2), (th1(t), th1), (th2(t), th2) ] x1 = l1*sin(th1(t)) y1 = -l1*cos(th1(t)) x2 = l1*sin(th1(t)) + l2*sin(th2(t)) y2 = -l1*cos(th1(t)) - l2*cos(th2(t)) vx1 = diff(x1, t) vx2 = diff(x2, t) vy1 = diff(y1, t) vy2 = diff(y2, t) # 拉格朗日量 L = m1/2*(vx1**2 + vy1**2) + m2/2*(vx2**2 + vy2**2) - m1*g*y1 - m2*g*y2 # 拉格朗日方程 def lagrange_equation(L, v): a = L.subs(D(v(t), t), tmp).diff(tmp).subs(tmp, D(v(t), t)) b = L.subs(D(v(t), t), tmp).subs(v(t), v).diff(v).subs(v, v(t)).subs(tmp, D(v(t), t)) c = a.diff(t) - b c = c.subs(sublist) c = trigsimp(simplify(c)) c = collect(c, [th1,th2,dth1,dth2,ddth1,ddth2]) return c eq1 = lagrange_equation(L, th1) eq2 = lagrange_equation(L, th2) ``` 执行此程序之后,eq1对应于 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb8bd43b.png) 的拉格朗日方程, eq2对应于 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb8cbe8a.png) 的方程。 由于sympy只能对符号变量求导数,即只能计算 D(L, t), 而不能计算D(f, v(t))。 因此在求偏导数之前,将偏导数变量置换为一个tmp变量,然后对tmp变量求导数,例如下面的程序行对D(v(t), t)求偏导数,即计算 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb98b30f.png) ``` L.subs(D(v(t), t), tmp).diff(tmp).subs(tmp, D(v(t), t)) ``` 而在计算 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb9989e5.png) 时,需要将v(t)替换为v之后再进行微分计算。由于将v(t)替换为v的同时,会将 D(v(t), t) 中的也进行替换,这是我们不希望的结果,因此先将 D(v(t), t) 替换为tmp,微分计算完毕之后再替换回去: ``` L.subs(D(v(t), t), tmp).subs(v(t), v).diff(v).subs(v, v(t)).subs(tmp, D(v(t), t)) ``` 最后得到的eq1, eq2的值为: ``` >>> eq1 ddth1*(m1*l1**2 + m2*l1**2) + ddth2*(l1*l2*m2*cos(th1)*cos(th2) + l1*l2*m2*sin(th1)*sin(th2)) + dth2**2*(l1*l2*m2*cos(th2)*sin(th1) - l1*l2*m2*cos(th1)*sin(th2)) + g*l1*m1*sin(th1) + g*l1*m2*sin(th1) >>> eq2 ddth1*(l1*l2*m2*cos(th1)*cos(th2) + l1*l2*m2*sin(th1)*sin(th2)) + dth1**2*(l1*l2*m2*cos(th1)*sin(th2) - l1*l2*m2*cos(th2)*sin(th1)) + g*l2*m2*sin(th2) + ddth2*m2*l2**2 ``` 结果看上去挺复杂,其实只要运用如下的三角公式就和前面的结果一致了: ![\sin \left(x+y\right)=\sin x \cos y + \cos x \sin y \cos \left(x+y\right)=\cos x \cos y - \sin x \sin y \sin \left(x-y\right)=\sin x \cos y - \cos x \sin y \cos \left(x-y\right)=\cos x \cos y + \sin x \sin y](img/5149d766733feeccc71ce88a0b4d99749b762842.png) ### 微分方程的数值解 接下来要做的事情就是对如下的微分方程求数值解: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb9a873c.png) ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb9b8376.png) 由于方程中包含二阶导数,因此无法直接使用odeint函数进行数值求解,我们很容易将其改写为4个一阶微分方程组,4个未知变量为: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb9c70bd.png) , 其中 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb9d67ab.png) 为两个杆转动的角速度。 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb9e5583.png) ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bba0099f.png) ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bba106ca.png) ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bba1d3db.png) 下面的程序利用 scipy.integrate.odeint 对此微分方程组进行数值求解: ``` # -*- coding: utf-8 -*- from math import sin,cos import numpy as np from scipy.integrate import odeint g = 9.8 class DoublePendulum(object): def __init__(self, m1, m2, l1, l2): self.m1, self.m2, self.l1, self.l2 = m1, m2, l1, l2 self.init_status = np.array([0.0,0.0,0.0,0.0]) def equations(self, w, t): """ 微分方程公式 """ m1, m2, l1, l2 = self.m1, self.m2, self.l1, self.l2 th1, th2, v1, v2 = w dth1 = v1 dth2 = v2 #eq of th1 a = l1*l1*(m1+m2) # dv1 parameter b = l1*m2*l2*cos(th1-th2) # dv2 paramter c = l1*(m2*l2*sin(th1-th2)*dth2*dth2 + (m1+m2)*g*sin(th1)) #eq of th2 d = m2*l2*l1*cos(th1-th2) # dv1 parameter e = m2*l2*l2 # dv2 parameter f = m2*l2*(-l1*sin(th1-th2)*dth1*dth1 + g*sin(th2)) dv1, dv2 = np.linalg.solve([[a,b],[d,e]], [-c,-f]) return np.array([dth1, dth2, dv1, dv2]) def double_pendulum_odeint(pendulum, ts, te, tstep): """ 对双摆系统的微分方程组进行数值求解,返回两个小球的X-Y坐标 """ t = np.arange(ts, te, tstep) track = odeint(pendulum.equations, pendulum.init_status, t) th1_array, th2_array = track[:,0], track[:, 1] l1, l2 = pendulum.l1, pendulum.l2 x1 = l1*np.sin(th1_array) y1 = -l1*np.cos(th1_array) x2 = x1 + l2*np.sin(th2_array) y2 = y1 - l2*np.cos(th2_array) pendulum.init_status = track[-1,:].copy() #将最后的状态赋给pendulum return [x1, y1, x2, y2] if __name__ == "__main__": import matplotlib.pyplot as pl pendulum = DoublePendulum(1.0, 2.0, 1.0, 2.0) th1, th2 = 1.0, 2.0 pendulum.init_status[:2] = th1, th2 x1, y1, x2, y2 = double_pendulum_odeint(pendulum, 0, 30, 0.02) pl.plot(x1,y1, label = u"上球") pl.plot(x2,y2, label = u"下球") pl.title(u"双摆系统的轨迹, 初始角度=%s,%s" % (th1, th2)) pl.legend() pl.axis("equal") pl.show() ``` 程序中的 DoublePendulum.equations 函数计算各个未知函数的导数,其输入参数w数组中的变量依次为: * th1: 上球角度 * th2: 下球角度 * v1: 上球角速度 * v2: 下球角速度 返回值为每个变量的导数: * dth1: 上球角速度 * dth2: 下球角速度 * dv1: 上球角加速度 * dv2: 下球角加速度 其中dth1和dth2很容易计算,它们直接等于传入的角速度变量: ``` dth1 = v1 dth2 = v2 ``` 为了计算dv1和dv2,需要将微分方程组进行变形为如下格式 : ![\dot v_1 = ... \dot v_2 = ...](img/47827c6cd286f5e2e5d906f8e99385366e0db189.png) 如果我们希望让程序做这个事情的话,可以计算出 dv1 和 dv2 的系数,然后调用 linalg.solve 求解线型方程组: ``` #eq of th1 a = l1*l1*(m1+m2) # dv1 parameter b = l1*m2*l2*cos(th1-th2) # dv2 paramter c = l1*(m2*l2*sin(th1-th2)*dth2*dth2 + (m1+m2)*g*sin(th1)) #eq of th2 d = m2*l2*l1*cos(th1-th2) # dv1 parameter e = m2*l2*l2 # dv2 parameter f = m2*l2*(-l1*sin(th1-th2)*dth1*dth1 + g*sin(th2)) dv1, dv2 = np.linalg.solve([[a,b],[d,e]], [-c,-f]) ``` 上面的程序相当于将原始的微分方程组变换为 ![a \dot v_1 + b \dot v_2 + c = 0 d \dot v_1 + e \dot v_2 + f = 0](img/0adfbc12bbe1ff9d4511bc29a532781e32e7609e.png) 程序绘制的小球运动轨迹如下: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bba2d608.png) 初始角度微小时的双摆的摆动轨迹 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bba3f9f8.png) 大初始角度时双摆的摆动轨迹呈现混沌现象 可以看出当初始角度很大的时候,摆动出现混沌现象。 ### 动画显示 计算出小球的轨迹之后我们很容易将结果可视化,制作成动画效果。制作动画可以有多种选择: * visual库可以制作3D动画 * pygame制作快速的2D动画 * tkinter或者wxpython直接在界面上绘制动画 这里介绍如何使用matplotlib制作动画。整个动画绘制程序如下: ``` # -*- coding: utf-8 -*- import matplotlib matplotlib.use('WXAgg') # do this before importing pylab import matplotlib.pyplot as pl from double_pendulum_odeint import double_pendulum_odeint, DoublePendulum fig = pl.figure(figsize=(4,4)) line1, = pl.plot([0,0], [0,0], "-o") line2, = pl.plot([0,0], [0,0], "-o") pl.axis("equal") pl.xlim(-4,4) pl.ylim(-4,2) pendulum = DoublePendulum(1.0, 2.0, 1.0, 2.0) pendulum.init_status[:] = 1.0, 2.0, 0, 0 x1, y1, x2, y2 = [],[],[],[] idx = 0 def update_line(event): global x1, x2, y1, y2, idx if idx == len(x1): x1, y1, x2, y2 = double_pendulum_odeint(pendulum, 0, 1, 0.05) idx = 0 line1.set_xdata([0, x1[idx]]) line1.set_ydata([0, y1[idx]]) line2.set_xdata([x1[idx], x2[idx]]) line2.set_ydata([y1[idx], y2[idx]]) fig.canvas.draw() idx += 1 import wx id = wx.NewId() actor = fig.canvas.manager.frame timer = wx.Timer(actor, id=id) timer.Start(1) wx.EVT_TIMER(actor, id, update_line) pl.show() ``` 程序中强制使用WXAgg进行后台绘制: ``` matplotlib.use('WXAgg') ``` 然后启动wx库中的时间事件调用update_line函数重新设置两条直线的端点位置: ``` import wx id = wx.NewId() actor = fig.canvas.manager.frame timer = wx.Timer(actor, id=id) timer.Start(1) wx.EVT_TIMER(actor, id, update_line) ``` 在update_line函数中,每次轨迹数组播放完毕之后,就调用: ``` if idx == len(x1): x1, y1, x2, y2 = double_pendulum_odeint(pendulum, 0, 1, 0.05) idx = 0 ``` 重新生成下一秒钟的轨迹。由于在 double_pendulum_odeint 函数中会将odeint计算的最终的状态赋给 pendulum.init_status ,因此连续调用 double_pendulum_odeint 函数可以生成连续的运动轨迹 ``` def double_pendulum_odeint(pendulum, ts, te, tstep): ... track = odeint(pendulum.equations, pendulum.init_status, t) ... pendulum.init_status = track[-1,:].copy() return [x1, y1, x2, y2] ``` 程序的动画效果如下图所示: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bba5b0d9.png) 双摆的摆动动画效果截图
';

自适应滤波器和NLMS模拟

最后更新于:2022-04-01 11:15:41

# 自适应滤波器和NLMS模拟 本章将简要介绍自适应滤波器的原理以及其最常用的算法NLMS,并给NLMS算法的两种实现方法:用纯Python编写,和用ctypes调用C语言编写。最后将对NLMS算法进行一些的实验。 ## 自适应滤波器简介 近年来,随着数字信号处理器的功能的不断增强,自适应信号处理 (adaptive signal process)活跃在噪声消除、回声控制、信号预测、声音定位等众多信号处理领域。 尽管其应用领域十分广泛,但基本的系统构造大致只有如下几种分类。 ### 系统辨识 所谓系统辨识(system identification),就是通过对未知系统的输入输出进行观测,构造一个滤波器使得它在同样的输入的情况下,输出信号和未知系统相同。简而言之,就是通过观测未知系统对输入的反应,探知其内部情况。为了探知内情而使用的输入信号我们称之为参照信号。 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb6664a8.png) 系统识别(System Identification)的框图 如上图所示参照信号 x(j)同时输入到未知系统和自适应滤波器H中,未知系统的输出为y(j), 自适应滤波器的输出为u(j),由于观测误差或者外部噪声的干扰,实际观测到的未知系统的输出为d(j)=y(j)+n(j),n(j)被称为外部干扰。通过求的d(j)和u(j)之间的误差e(j)=d(j)-u(j),我们可以知道自适应滤波器H和未知系统还有多少差别,通过这个误差我们更新H的内部参数,使得它更加靠近未知系统。 上面各个公式中的j表示某一时刻,因为我们讨论的是数字信号处理,已经对所有的信号进行取样,因此可以把j简单的看作取样点的下标。 ### 信号预测 所谓信号预测就是通过信号过去的值预测(计算)现在的值,下面是信号预测的系统框图。 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb673543.png) 信号预测(Predication)框图 x(j)是待预测的信号,假设我们无法完美地观测此信号,因此导入一个外部干扰n(j),这样d(j)=x(j)+n(j)就是我们观测到的待预测信号。 通过延迟器将d(j)进行延时得到d(j-D),并把d(j-D)输入到自适应滤波器H中,得到其输出为u(j),u(j)就是自适应滤波器通过待预测信号过去的值预测出的现在的值,计算观测值d(j)和预测值u(j)之间的误差e(j)=d(j)-u(j),通过e(j)更新自适应滤波器H的内部系数使得其输出更加接近d(j)。 如果x(j)存在白色噪声的成分和周期信号的成分,由于白色噪声是完全不自相关,无法预测的信号,因此通过过去的值x(j-D)所能预测的只能是其中的周期信号的成分。这样自适应滤波器H的输出信号u(j)就会与周期信号成分渐渐逼近,而e(j)则是剩下的不可预测的白色噪声的成分。因此自适应滤波器也可以运用于噪声消除。 ### 信号均衡 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb681a4b.png) 信号均衡(Equalization)框图 当信号x(j)通过未知系统之后变成y(j),未知系统对信号x(j)进行了某种改变,使得其波形产生歪曲。我们希望均衡器矫正这种歪曲,也就是通过y(j)重建原始信号x(j),由于因果律还原原始信号x(j)是不可能的,我们只能还原其延时了的信号x(j-D)。x(j)和x(j-D)除了时间上的延迟之外,其它特性完全相同。 这里我们将观测到的未知系统的输出y(j)+n(j)输入到自适应滤波器H中,通过H的系数更新使得其输出u(j)逐渐逼近原始信号的延时x(j-D)。这样我们就构建了一个滤波器H使得它与未知系统的卷积正好等于一个脉冲传递函数。也就是说H的频域特性恰好能抵消未知系统的所带来的改变。 ## NLMS计算公式 自适应滤波器中最重要的一个环节就是其系数的更新算法,如果不对自适应滤波器的系数更新的话,那么它就只是一个普通的FIR滤波器了。系数更新算法有很多种类,最基本、常用、简单的一种方法叫做NLMS(归一化最小均方),让我们先来看看它的数学公式表达: 设置自适应滤波器系数 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb690b5e.png) 的所有初始值为0, ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb690b5e.png) 的长度为I。 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb6a217d.png) 对每个取样值进行如下计算,其中n=0, 1, 2, ... ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb6b03cb.png) ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb6c53a9.png) ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb6dd37d.png) 自适应滤波器系数 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb690b5e.png) 是一个长度为I的矢量,也就是一个长度为I的FIR滤波器。在时刻n,滤波器的每个系数对应的输入信号为 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb6eea00.png) ,它也是一个长度为I的矢量。这两个矢量的点乘即为滤波器的输出和目标信号d(n)之间的差为e(n),然后根据e(n)和 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb6eea00.png) , 更新滤波器的系数。 数学公式总是令人难以理解的,下面我们以图示为例进行说明: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb70c21e.png) NLMS算法示意图 图中假设自适应滤波器h的长度为4,在时刻7滤波器的输出为: ``` u[7] = h[0]*x[7] + h[1]*x[6] + h[2]*x[5] + h[3]*x[4] ``` 滤波器的输入信号的平方和powerX为: ``` powerX = x[4]*x[4] + x[5]*x[5] + x[6]*x[6] + x[7]*x[7] ``` 未知系统的输出d[7]和滤波器的输出u[7]之间的差为: ``` e[7] = d[7] - u[7] ``` 使用u[7]和x[4]..x[7]对滤波器的系数更新: ``` h[4] = h[4] + u * e[7]*x[4]/powerX h[4] = h[5] + u * e[7]*x[5]/powerX h[4] = h[6] + u * e[7]*x[6]/powerX h[4] = h[7] + u * e[7]*x[7]/powerX ``` 其中参数u成为更新系数,为0到1之间的一个实数,此值越大系数更新的速度越快。对于每个时刻i都需要进行上述的计算,因此滤波器的系数对于每个参照信号x的取样都更新一次。 ## NumPy实现 按照上面介绍的NLMS算法,我们很容易写出用NumPy实现的NLMS计算程序: ``` # -*- coding: utf-8 -*- # filename: nlms_numpy.py import numpy as np # 用Numpy实现的NLMS算法 # x为参照信号,d为目标信号,h为自适应滤波器的初值 # step_size为更新系数 def nlms(x, d, h, step_size=0.5): i = len(h) size = len(x) # 计算输入到h中的参照信号的乘方he power = np.sum( x[i:i-len(h):-1] * x[i:i-len(h):-1] ) u = np.zeros(size, dtype=np.float64) while True: x_input = x[i:i-len(h):-1] u[i] = np.dot(x_input , h) e = d[i] - u[i] h += step_size * e / power * x_input power -= x_input[-1] * x_input[-1] # 减去最早的取样 i+=1 if i >= size: return u power += x[i] * x[i] # 增加最新的取样 ``` 为了节省计算时间,我们用一个临时变量power保存输入到滤波器h中的参照信号x的能量。在对于x中的每个取样的循环中,power减去x中最早的一个取样值的乘方,增加最新的一个取样值的乘方。这样为了计算参照信号的能量,每次循环只需要计算两次乘法和两次加法即可。 nlms函数的输入为参照信号x、目标信号d和自适应滤波器的系数h。因为在后面的模拟计算中,d是x和未知系统的脉冲响应的卷积而计算的来,它的长度会大于x的参数,因此循环体的循环次数以参照信号的长度为基准。 为了对自适应滤波器的各种应用进行模拟,我们还需要如下的几个辅助函数,完整的程序请参考 [_NLMS算法的模拟测试_](example_nlms_test.html) 。 ``` def make_path(delay, length): path_length = length - delay h = np.zeros(length, np.float64) h[delay:] = np.random.standard_normal(path_length) * np.exp( np.linspace(0, -4, path_length) ) h /= np.sqrt(np.sum(h*h)) return h ``` make_path产生一个长度为length,最小延时为delay的指数衰减的波形。这种波形和封闭空间的声音的传递函数有些类似之处,因此在计算机上进行声音的算法模拟时经常用这种波形作为系统的传递函数。: ``` def plot_converge(y, u, label=""): size = len(u) avg_number = 200 e = np.power(y[:size] - u, 2) tmp = e[:int(size/avg_number)*avg_number] tmp.shape = -1, avg_number avg = np.average( tmp, axis=1 ) pl.plot(np.linspace(0, size, len(avg)), 10*np.log10(avg), linewidth=2.0, label=label) def diff_db(h0, h): return 10*np.log10(np.sum((h0-h)*(h0-h)) / np.sum(h0*h0)) ``` plot_converge绘制信号y和信号u之间的误差,每avg_number个取样点就上一次误差的乘方的平均值。我们将用plot_converge函数绘制未知系统的输出y和自适应滤波器的输出u之间的误差。观察自适应滤波器是如何收敛的,以评价自适应滤波器的收敛特性。diff_db函数同样是用来评价自适应滤波器的收敛特性,不过他是直接计算未知系统的传递函数h0和自适应滤波器的传递函数h之间的误差。下面我们会看到这两个函数得到的收敛值是相同的。 ### 系统辨识模拟 我们用下面的函数调用nlms算法对系统辨识应用进行模拟: ``` def sim_system_identify(nlms, x, h0, step_size, noise_scale): y = np.convolve(x, h0) d = y + np.random.standard_normal(len(y)) * noise_scale # 添加白色噪声的外部干扰 h = np.zeros(len(h0), np.float64) # 自适应滤波器的长度和未知系统长度相同,初始值为0 u = nlms( x, d, h, step_size ) return y, u, h ``` ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb6664a8.png) 系统识别(System Identification)的框图 此函数的参数分别为: * **nlms** : nlms算法的实现函数 * **x** : 参照信号 * **h0** : 未知系统的传递函数,虽然是未知系统,但是计算机模拟时它是已知的 * **step_size** : nlms算法的更新系数 * **noise_scale** : 外部干扰的系数,此系数决定外部干扰的大小,0表示没有外部干扰 函数的返回值分别为: * **y** : 未知系统的输出,不包括外部干扰 * **u** : 自适应滤波器的输出 * **h** : 自适应滤波器的最终的系数 最后我们用下面的函数创建未知系统h0, 参照信号x,然后调用sim_system_identify函数得到结果并且绘图: ``` def system_identify_test1(): h0 = make_path(32, 256) # 随机产生一个未知系统的传递函数 x = np.random.standard_normal(10000) # 参照信号为白噪声 y, u, h = sim_system_identify(nlms_numpy.nlms, x, h0, 0.5, 0.1) print diff_db(h0, h) pl.figure( figsize=(8, 6) ) pl.subplot(211) pl.subplots_adjust(hspace=0.4) pl.plot(h0, c="r") pl.plot(h, c="b") pl.title(u"未知系统和收敛后的滤波器的系数比较") pl.subplot(212) plot_converge(y, u) pl.title(u"自适应滤波器收敛特性") pl.xlabel("Iterations (samples)") pl.ylabel("Converge Level (dB)") pl.show() ``` ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb721451.png) 自适应滤波器收敛之后的系数和收敛速度 上部的图显示的是未知系统(红色)和自适应滤波器(蓝色)的传递函数的系数,我们看到自适应滤波器已经十分接近未知系统了。diff_db(h0, h)的输出为-25.35dB。下部的图通过绘制y和u之间的误差,显示了自适应滤波器的收敛过程。我们看到经过约3000点的计算之后,收敛过程已经饱和,最终的误差为-25dB左右,和diff_db计算的结果一致。 从图中可以看到收敛过程的两个重要特性:收敛时间和收敛精度。参照信号的特性、外部干扰的大小和更新系数都会影响这两个特性。下面让我们看看参照信号为白色噪声、外部干扰的能量固定时,更新系数对它们影响: ``` def system_identify_test2(): h0 = make_path(32, 256) # 随机产生一个未知系统的传递函数 x = np.random.standard_normal(20000) # 参照信号为白噪声 pl.figure(figsize=(8,4)) for step_size in np.arange(0.1, 1.0, 0.2): y, u, h = sim_system_identify(nlms_numpy.nlms, x, h0, step_size, 0.1) plot_converge(y, u, label=u"μ=%s" % step_size) pl.title(u"更新系数和收敛特性的关系") pl.xlabel("Iterations (samples)") pl.ylabel("Converge Level (dB)") pl.legend() pl.show() ``` ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb738e76.png) 更新系数和收敛速度的关系 下面是更新系数固定,外部干扰能量变化时的收敛特性: ``` def system_identify_test3(): h0 = make_path(32, 256) # 随机产生一个未知系统的传递函数 x = np.random.standard_normal(20000) # 参照信号为白噪声 pl.figure(figsize=(8,4)) for noise_scale in [0.05, 0.1, 0.2, 0.4, 0.8]: y, u, h = sim_system_identify(nlms_numpy.nlms, x, h0, 0.5, noise_scale) plot_converge(y, u, label=u"noise=%s" % noise_scale) pl.title(u"外部干扰和收敛特性的关系") pl.xlabel("Iterations (samples)") pl.ylabel("Converge Level (dB)") pl.legend() pl.show() ``` ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb753d6b.png) 外部干扰噪声和收敛速度的关系 从上面的图可以看出,当外部干扰的振幅增加一倍、能能量增加6dB时,收敛精度降低6dB。而由于更新系数相同,所以收敛过程中的收敛速度都是一样的。 ### 信号均衡模拟 对于信号均衡的应用我们用如下的程序进行模拟: ``` def sim_signal_equation(nlms, x, h0, D, step_size, noise_scale): d = x[:-D] x = x[D:] y = np.convolve(x, h0)[:len(x)] h = np.zeros(2*len(h0)+2*D, np.float64) y += np.random.standard_normal(len(y)) * noise_scale u = nlms(y, d, h, step_size) return h ``` ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb681a4b.png) 信号均衡(Equalization)框图 sim_signal_equation函数的参数: * **nlms** : nlms算法的实现函数 * **x** : 未知系统的输入信号 * **h0** : 未知系统的传递函数 * **D** : 延迟器的延时参数 * **step_size** : nlms算法的更新系数 * **noise_scale** : 外部干扰的系数,此系数决定外部干扰的大小,0表示没有外部干扰 在函数中的各个局部变量: * **d** : 输入信号经过延迟器之后的信号 * **y** : 未知系统的输出 * **h** : 自适应滤波器的系数,它的长度要足够长,程序中使用 2倍延时 + 2倍未知系统的传递函数的长度 函数的返回值为自适应滤波器收敛后的系数,它能够均衡h0对输入信号所造成的影响。我们通过下面的函数产生数据、调用模拟函数以及绘制结果: ``` def signal_equation_test1(): h0 = make_path(5, 64) D = 128 length = 20000 data = np.random.standard_normal(length+D) h = sim_signal_equation(nlms_numpy.nlms, data, h0, D, 0.5, 0.1) pl.figure(figsize=(8,4)) pl.plot(h0, label=u"未知系统") pl.plot(h, label=u"自适应滤波器") pl.plot(np.convolve(h0, h), label=u"二者卷积") pl.title(u"信号均衡演示") pl.legend() w0, H0 = scipy.signal.freqz(h0, worN = 1000) w, H = scipy.signal.freqz(h, worN = 1000) pl.figure(figsize=(8,4)) pl.plot(w0, 20*np.log10(np.abs(H0)), w, 20*np.log10(np.abs(H))) pl.title(u"未知系统和自适应滤波器的振幅特性") pl.xlabel(u"圆频率") pl.ylabel(u"振幅(dB)") pl.show() ``` 如果延迟器的延时D不够的话,会由于因果律使得自适应滤波器无法收敛。因此这里我们采用的D的长度为h0的长度的2倍。下图显示h0, h和它们的卷积。我们看到h0和h的卷积正好是一个脉冲,其延时为正好等于D(128)。 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb773ba8.png) 未知系统和自适应滤波器的级联(卷积)近似为标准延迟 下图显示未知系统的频率响应(蓝色)和自适应滤波器的频率响应(绿色),我们看到二者正好相反,也就是说自适应滤波器均衡了未知系统对信号的影响。 ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb7874a9.png) 未知系统和自适应滤波器的频率响应正好相反 ### 卷积逆运算 虽然卷积运算最终能归结为简单的加法和乘法运算,然而卷积的逆运算就不是很容易计算了。我们知道两个线性系统h1和h2的级联h3可以用它们的脉冲响应的卷积计算求得,而所谓卷积的逆运算可以想象为已知h3和h1,求一个h2使它和h1级联之后正好等于h3。 根据卷积的计算公式可知,如果h1的长度为100,h3的长度为199,那么h2的长度则为100,因为h2的每个系数都是未知的,于是就有100个未知数,而这100个未知数需要满足199个线性方程:h3中的每个系数都有一个方程与之对应。由于方程数大于未知数的个数,显然对于任意的h1和h3并不能保证有一个h2使得它和h1的卷积正好等于h3。 既然不能精确求解,那么卷积的逆运算就变成了一个误差最小化的优化问题。用自适应滤波器计算卷积的逆运算和计算信号均衡类似,将白色噪声x输入到h1中得到信号u,将x输入到h3中得到信号d,然后使用u作为参照信号,d作为目标信号进行NLMS计算,最终收敛后的自适应滤波器的系数就是h2。 下面的程序模拟这一过程: ``` # -*- coding: utf-8 -*- import numpy as np import pylab as pl from nlms_numpy import nlms import scipy.signal as signal def inv_convolve(h1, h3, length): x = np.random.standard_normal(10000) u = signal.lfilter(h1, 1, x) d = signal.lfilter(h3, 1, x) h = np.zeros(length, np.float64) nlms(u, d, h, 0.1) return h h1 = np.fromfile("h1.txt", sep="\n") h1 /= np.max(h1) h3 = np.fromfile("h3.txt", sep="\n") h3 /= np.max(h3) pl.rc('legend', fontsize=10) pl.subplot(411) pl.plot(h3, label="h3") pl.plot(h1, label="h1") pl.legend() pl.gca().set_yticklabels([]) for idx, length in enumerate([128, 256, 512]): pl.subplot(412+idx) h2 = inv_convolve(h1, h3, length) pl.plot(np.convolve(h1, h2)[:len(h3)], label="h1*h2(%s)" % length) pl.legend() pl.gca().set_yticklabels([]) pl.gca().set_xticklabels([]) pl.show() ``` 下面是程序的计算结果: ![](https://docs.gechiui.com/gc-content/uploads/sites/kancloud/2016-03-19_56ed1bb7a0103.png) 卷积逆运算演示 程序中的h1和h3从文本文件中读取而得,它们是ANC(能动噪声控制)系统中实际测量的脉冲响应。如果能找到一个h2满足卷积条件的话,就能够有效的进行噪声控制。 程序计算出h2的长度分别为128, 256, 512时的结果,可以看出h2越长结果越精确。 ## DLL函数的编写 ## ctypes的python接口
';

Ctypes和NumPy

最后更新于:2022-04-01 11:15:39

# Ctypes和NumPy ## 用ctypes加速计算 Ctypes是Python处理动态链接库的标准扩展模块,在Windows下使用它可以直接调用C语言编写的DLL动态链接库。由于对传递的参数没有类型和越界检查,因此如果编写的代码有问题的话,很可能会造成程序崩溃。当将数组数据使用指针传递时,出错误的风险将更加大。 为了让程序更加安全,通常会用Python代码对Ctypes调用进行包装,在调用Ctypes之前,在Python级别对数据类型和越界进行检查。这样做会使得调用接口部分比其它的一些手工编写的扩展模块速度要慢,但是如果C语言的代码段处理相当多的数据的话,接口调用部分的速度损失是可以忽略不计的。 ## 用ctypes调用DLL 为了使用CTypes,你必须依次完成以下步骤: * 编写动态连接库程序 * 载入动态连接库 * 将Python的对象转换为ctypes所能识别的参数 * 使用ctypes的参数调用动态连接库中的函数 下面我们来看看如何用ctypes调用动态链接库。 ## numpy对ctypes的支持 为了方便动态连接库的载入,numpy提供了一个便捷函数ctypeslib.load_library。它有两个参数,第一个参数是库的文件名,第二个参数是库所在的路径。函数返回的是一个ctypes的对象。通过此对象的属性可以直接到动态连接库所提供的函数。 例如如果我们有一个库名为test_sum.dll,其中提供了一个函数mysum : ``` double mysum(double a[], long n) { double sum = 0; int i; for(i=0;i<n;i++) sum += a[i]; return sum; } ``` 的话,我们可以使用如下语句载入此库: ``` >>> from ctypes import * >>> sum_test = np.ctypeslib.load_library("sum_test", ".") >>> print sum_test.mysum <_FuncPtr object at 0x037D7210> ``` 要正确调用sum函数,还必须对其参数类型进行说明,下面的语句描述了sum函数的两个参数的类型和返回值的类型进行描述: ``` >>> sum_test.mysum.argtypes = [POINTER(c_double), c_long] >>> sum_test.mysum.restype = c_double ``` 接下来就可以正常调用sum函数了: ``` >>> x = np.arange(1, 101, 1.0) >>> sum_test.mysum(x.ctypes.data_as(POINTER(c_double)), len(x)) 5050.0 ``` 每次调用sum都需要进行类型转换时比较麻烦的事情,因此可以编写一个Python的mysum函数,将C语言的mysum函数包装起来: ``` def mysum(x): return sum_test.mysum(x.ctypes.data_as(POINTER(c_double)), len(x)) ``` 在上面的例子中,test_sum.mysum的参数值使用标准的ctypes类型声明:用POINTER(c_double)声明mysum函数的第一个参数是一个指向double的指针;然后调用数组x的x.ctypes.data_as函数将x转换为一个指向double的指针类型。 由于数组的元素在内存中的存储可以是不连续的,而且可以是多维数组,因此我们不能指望前面的mysum函数能够处理所有的情况: ``` >>> x = np.arange(1,11,1.0) >>> mysum(x[::2]) 15.0 >>> sum(x[::2]) 25.0 ``` 由于x[::2]和x共同一块内存空间,而x[::2]中的元素是不连续的,每个元素之间的间隔为16byptes(2个double的大小)。因此将它传递给mysum的话,实际上计算的是x数组中前5项的和:1+2+3+4+5=15,而实际上我们希望的结果是:1+3+5+7+9=25。 为了对传递的数组参数进行更加详细的描述,numpy库提供了ndpointer函数。ndpointer函数对restype和argtypes中的数组参数进行描述,他有如下4个参数: * **dtype** : 数组的元素类型 * **ndim** : 数组的维数 * **shape** : 数组的形状,各个轴的长度 * **flags** : 数组的标志 例如: ``` test_sum.mysum.argtypes = [ np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, flags="C_CONTIGUOUS"), c_long ] ``` 描述了sumfunc函数的参数为一个元素类型为double的、一维的、连续的元素按C语言规定排列的数组。 这时传递给mysum函数的第一个参数可以直接是数组,因此无需再编写一个Python函数对其进行包装: ``` >>> sum_test.mysum(x,len(x)) 55.0 >>> sum_test.mysum(x[::2],len(x)/2) ArgumentError: argument 1: <type 'exceptions.TypeError'>: array must have flags ['C_CONTIGUOUS'] ``` 我们注意到如果参数数组不是连续空间的话,mysum函数的调用会抛出异常错误,提醒我们其参数需要C语言排列的连续数组。 如果我们希望它能够处理多维、不连续的数组的话,就需要把数组的shape和strides属性都传递给过去。假设我们想写一个通用的mysum2函数,它可以对二维数组的所有元素进行求和。下面是C语言的程序: ``` double mysum2(double a[], int strides[], int shapes[]) { double sum = 0; int i, j, M, N, S0, S1; M = shape[0]; N=shape[1]; S0 = strides[0] / sizeof(double); S1 = strides[1] / sizeof(double); for(i=0;i<M;i++){ for(j=0;j<N;j++){ sum += a[i*S0 + j*S1]; } } return sum; } ``` mysum2函数有3个参数,第一个参数a[]指向保存数组数据的内存块;第二个参数astrides指向保存数组各个轴元素之间的间隔(以byte为单位);第三个参数dims指向保存数组各个轴长度的数组。 由于strides保存的是以byte为单位的间隔长度,因此需要除以sizeof(double)计算出以double为单位的间隔长度S0和S1。这样二维数组a中的第i行、第j列的元素可以通过a[i*S0 + j*S1]来存取。下面用ctypes对mysum2函数进行包装: ``` sum_test.mysum2.restype = c_double sum_test.mysum2.argtypes = [ np.ctypeslib.ndpointer(dtype=np.float64, ndim=2), POINTER(c_int), POINTER(c_int) ] def mysum2(x): return sum_test.mysum2(x, x.ctypes.strides, x.ctypes.shape) ``` 在mysum2函数中,为了将数组x的strides和shape属性传递给C语言的函数,可以使用x.ctypes中提供的strides和shape属性。注意不能直接传递x.strides和x.shape,因为这些是python的tuple对象,而x.ctypes.shape得到的是ctypes包装的整数数组: ``` >>> x = np.zeros((3,4), np.float) >>> x.ctypes.shape <numpy.core._internal.c_long_Array_2 object at 0x020B4DF0> >>> s = x.ctypes.shape >>> s[0] 3 >>> s[1] 4 ``` 可以看出x.ctypes.shape是一个有两个元素的C语言长整型数组。虽然我们也可以在Python中通过下标读取其各个元素的值,但是通常它们是作为参数传递给C语言函数用的。
';