Interactive widget to simulate a focused wavefront

[1]:
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
from pynx.wavefront import *

import ipywidgets as widgets

# Set to True to widen the notebook cells to the full browser width.
# NOTE(review): the `.container` CSS selector presumably only affects the classic
# Jupyter Notebook interface, not JupyterLab -- confirm for your front-end.
WIDE_DISPLAY = False
if WIDE_DISPLAY:
    # IPython.display is the public home of display/HTML; importing `display` from
    # IPython.core.display is deprecated.
    from IPython.display import display, HTML
    display(HTML("<style>.container { width:100% !important; }</style>"))

[2]:
class WidgetSimulFocus(widgets.Box):
    """Interactive panel simulating a wavefront cut by an aperture and focused.

    A uniform complex wavefront is multiplied by a circular or square aperture mask.
    It is then optionally propagated with PyNX operators: far-field propagation over
    the focal distance, followed by near-field propagation over the defocus distance.
    The result is drawn in an embedded matplotlib figure as amplitude, phase or RGBA.

    Controls that change the wavefront sampling (pixel size, array size, X-ray energy)
    rebuild the wavefront via :meth:`init_wavefront`. All other controls only trigger
    :meth:`plot`.
    """

    def __init__(self):
        super().__init__()

        focus_label = widgets.Label(value='Focal distance (cm):')
        self.focus = self._slider(value=10, min=1, max=50, step=0.1)
        defocus_label = widgets.Label(value='Defocus (µm):')
        self.defocus = self._slider(value=0, min=-2000, max=2000, step=50)

        self.display_type = widgets.RadioButtons(options=['RGBA', 'Amplitude', 'Phase'],
                                                 value='Amplitude', orientation='horizontal', disabled=False)

        aperture_label = widgets.Label(value='Aperture (µm):')
        self.aperture = self._slider(value=200, min=40, max=500, step=20)

        self.aperture_type = widgets.RadioButtons(options=['Circle', 'Square'],
                                                  value='Circle', orientation='horizontal', disabled=False)

        pixel_label = widgets.Label(value='Pixel size @aperture (µm):')
        self.pixel = self._slider(value=2, min=0.2, max=5, step=0.1)

        wsize_label = widgets.Label(value='Array size:')
        # Log slider: array size is 2**7 .. 2**12 in powers of two.
        self.wsize = self._slider(widgets.FloatLogSlider, readout_format='.0f',
                                  value=512, base=2, min=7, max=12, step=1)

        nrj_label = widgets.Label(value='X-ray energy (keV):')
        self.nrj_kev = self._slider(value=10, min=1, max=40, step=0.5)

        display_location_label = widgets.Label(value='Plot wavefront at:')
        self.display_location = widgets.RadioButtons(options=['aperture (before propagation)', 'focus/defocus'],
                                                     value='focus/defocus', orientation='horizontal', disabled=False)

        vbox = widgets.VBox([focus_label, self.focus, defocus_label, self.defocus, self.display_type,
                             aperture_label, self.aperture, self.aperture_type,
                             pixel_label, self.pixel, wsize_label, self.wsize, nrj_label, self.nrj_kev,
                             display_location_label, self.display_location])

        output_fig = widgets.Output()
        with output_fig:
            self.fig = plt.figure()
            self.fig.canvas.header_visible = False  # Hide fig num

        # Only react to actual value changes. Without names='value', observers are
        # called for every trait notification, not just value changes.
        for control in (self.focus, self.defocus, self.display_type, self.aperture,
                        self.aperture_type, self.display_location):
            control.observe(self.plot, names='value')

        # These change the wavefront sampling or wavelength, so it must be rebuilt.
        for control in (self.pixel, self.wsize, self.nrj_kev):
            control.observe(self._on_wavefront_param_change, names='value')

        self.children = [widgets.HBox([output_fig, vbox])]

        self.last_plot_params = None
        self.init_wavefront(plot=True)

    @staticmethod
    def _slider(slider_cls=None, readout_format='.1f', **kwargs):
        """Return a horizontal slider with a readout that only reports on release.

        :param slider_cls: slider class to instantiate; defaults to ``widgets.FloatSlider``.
        :param readout_format: format string for the displayed value.
        :param kwargs: forwarded to the slider (value, min, max, step, base, ...).
        """
        if slider_cls is None:
            slider_cls = widgets.FloatSlider
        return slider_cls(disabled=False, continuous_update=False, orientation='horizontal',
                          readout=True, readout_format=readout_format, **kwargs)

    def _wavelength(self):
        """Return the wavelength in metres for the energy slider value in keV."""
        # lambda[Å] = 12.3984 / E[keV], i.e. 12.3984e-10 / E[keV] in metres.
        return 12.3984e-10 / self.nrj_kev.value

    def _on_wavefront_param_change(self, change):
        """Handle ``observe`` events from controls that change the wavefront sampling.

        :param change: traitlets change notification (unused).
        """
        self.init_wavefront(plot=True)

    def init_wavefront(self, plot=False):
        """(Re)create the uniform wavefront from the sampling controls.

        The wavefront is an ``n x n`` complex64 array of ones, where ``n`` comes from the
        array-size slider. The pixel size comes from the pixel slider (µm -> m) and the
        wavelength from the energy slider.

        :param plot: if true, redraw the figure afterwards via :meth:`plot`.
        """
        n = int(self.wsize.value)
        self.w = Wavefront(d=np.ones((n, n), dtype=np.complex64),
                           pixel_size=self.pixel.value * 1e-6,
                           wavelength=self._wavelength())
        if plot:
            self.plot()

    def plot(self, k=None, force_plot=False):
        """Apply the aperture (and optional propagation) to the wavefront and display it.

        The redraw is skipped when no control value changed since the last successful
        plot, unless ``force_plot`` is true.

        :param k: change notification passed by ``observe`` callbacks; unused.
        :param force_plot: redraw even if the parameters are unchanged.
        """
        n = int(self.wsize.value)
        plot_params = [n, self.focus.value, self.defocus.value, self.aperture.value,
                       self.aperture_type.value, self.pixel.value, self.nrj_kev.value,
                       self.display_type.value, self.display_location.value]
        # Previously `... or force_plot`, which made force_plot=True skip the redraw.
        if plot_params == self.last_plot_params and not force_plot:
            return

        # Reset the stored wavefront: uniform field, z=0, aperture-plane pixel size.
        # NOTE(review): PyNX operators presumably modify the wavefront in place (hence
        # this reset of z and pixel_size before each simulation) -- confirm.
        w = self.w
        w.set(np.ones((n, n), dtype=np.complex64))
        w.z = 0
        w.pixel_size = self.pixel.value * 1e-6

        # Aperture slider value is the square side / circle diameter in µm.
        size = self.aperture.value * 1e-6
        if self.aperture_type.value == 'Square':
            w = RectangularMask(width=size, height=size) * w
        else:
            w = CircularMask(radius=size / 2) * w

        if self.display_location.value == 'focus/defocus':
            # Far-field step over the focal distance (cm -> m), then near-field by defocus (µm -> m).
            w = PropagateFarField(self.focus.value * 1e-2, forward=False) * w
            w = PropagateNearField(self.defocus.value * 1e-6) * w
            tit = "f=%6.2fcm defocus=%5.0fµm" % (self.focus.value, self.defocus.value)
        else:
            tit = "Aperture: %s, size=%4.0fµm" % (self.aperture_type.value, self.aperture.value)

        fig_num = self.fig.number
        display_type = self.display_type.value
        if display_type == 'RGBA':
            w = ImshowRGBA(title=tit, fig_num=fig_num, colorwheel=False) * w
        elif display_type == 'Amplitude':
            w = ImshowAbs(title=tit, fig_num=fig_num) * w
        elif display_type == 'Phase':
            w = ImshowAngle(title=tit, fig_num=fig_num) * w

        self.last_plot_params = plot_params

# Create the widget; as the last expression of the cell, Jupyter displays it.
WidgetSimulFocus()

[2]:
[ ]: