Example: ptychographic reconstruction with position optimisations

This example uses simulated data and tracks how the scan positions change over the optimisation cycles.

[1]:
#import os
#os.environ['PYNX_PU'] = 'opencl'  # Select language and/or GPU name or rank through environment variable (optional)

%matplotlib notebook
import matplotlib.pyplot as plt
import numpy as np

from pynx.ptycho import simulation, shape

# Import Ptycho, PtychoData and operators (automatically selecting OpenCL or CUDA)
from pynx.ptycho import *

Simulate the ptychographic data

[2]:
# Simulation geometry (SI units, metres)
n = 256  # frame size in pixels: probe and diffraction patterns are n x n
nb_frame = 120  # number of scan positions
pixel_size_detector = 55e-6
wavelength = 1.5e-10
detector_distance = 1
obj_info = {'type': 'phase_ampl', 'phase_stretch': np.pi, 'ampl_range': (0.8,1.2), 'alpha_win': .2}
probe_info = {'type': 'focus', 'aperture': (120e-6, 120e-6), 'focal_length': .08, 'defocus': 800e-6, 'shape': (n, n)}

# Object-plane pixel size: wavelength * distance / (detector pixel size * n)
pixel_size_object = wavelength * detector_distance / pixel_size_detector / n

# 50 scan positions correspond to 4 turns, 78 to 5 turns, 113 to 6 turns
scan_info = {'type': 'spiral', 'scan_step_pix': 30, 'n_scans': nb_frame, 'integer_values': False}

# Optional: turn the last frame into one without sample (direct beam), which serves
# as an absolute reference. Its position is moved far away (1e20) and then masked.
use_direct_beam_reference = False
if use_direct_beam_reference:
    s = simulation.Simulation(obj_info=None, probe_info=None, scan_info=scan_info, data_info=None, verbose=False)
    s.make_scan()
    posx, posy = s.scan.values
    posx[-1] = 1e20
    posy[-1] = 1e20
    posx = np.ma.masked_array(posx, posx >= 1e10)
    posy = np.ma.masked_array(posy, posy >= 1e10)
    scan_info = {'type': 'custom', 'x': posx, 'y': posy}

data_info = {'num_phot_max': 1e9, 'bg': 0, 'wavelength': wavelength, 'detector_distance': detector_distance,
             'detector_pixel_size': pixel_size_detector,
             'noise': 'poisson'}

# Initialisation of the simulation with the parameters above. A specific <object>, <probe> or
# <scan> can instead be passed directly, e.g.:
#   s = simulation.Simulation(obj=<object>, probe=<probe>, scan=<scan>)
# omitting obj_info, probe_info or scan_info (or passing it as an empty dictionary "{}")
s = simulation.Simulation(obj_info=obj_info, probe_info=probe_info, scan_info=scan_info, data_info=data_info)

# Data simulation. Afterwards probe.show(), obj.show() and scan.show() display the inputs, and
# s.show_illumination_sum() visualises the integrated total coverage of the beam
s.make_data()

# Scan positions in object pixels
posx, posy = s.scan.values

ampl = s.amplitude.values  # square root of the measured diffraction pattern intensity
Simulating object: phase_ampl
Simulating probe: focus
WARNING: exceeding maximum near field propagation distance: z=0.000800 > 0.000194
Simulating scan: spiral
Simulating ptychographic data [120 frames].


Parameters of the simulation:
Data info: {'pix_size_direct_nm': 10, 'num_phot_max': 1000000000.0, 'nb_photons_per_frame': 100000000.0, 'bg': 0, 'beam_stop_transparency': 0, 'noise': 'poisson', 'wavelength': 1.5e-10, 'detector_distance': 1, 'detector_pixel_size': 5.5e-05}
Scan info: {'type': 'spiral', 'scan_step_pix': 30, 'n_scans': 120, 'integer_values': False}
Object info: {'type': 'Custom', 'phase_stretch': 3.141592653589793, 'ampl_range': (0.8, 1.2), 'alpha_win': 0.2}
Probe info: {'type': 'focus', 'shape': (256, 256), 'sigma_pix': (50, 50), 'rotation': 0, 'aperture': (0.00012, 0.00012), 'focal_length': 0.08, 'defocus': 0.0008}

Create the initial reconstructed object & probe

The initial object is random (amplitude between 0.5 and 1, phase between 0 and 0.5 radians), and the probe is different from the one used to simulate the diffraction patterns.

[3]:
# Object array size needed to cover every scan position for this frame size
nyo, nxo = shape.calc_obj_shape(posx, posy, ampl.shape[1:])

# Initial object: random amplitude in [0.5, 1] and phase in [0, 0.5] rad.
# Alternative flat start: obj_init_info = {'type': 'flat', 'shape': (nyo, nxo)}
obj_init_info = {'type': 'random', 'range': (0.5, 1, 0, 0.5), 'shape': (nyo, nxo)}
# Initial probe: focused beam with a smaller aperture and defocus than the simulated probe
probe_init_info = {'type': 'focus', 'aperture': (100e-6, 100e-6), 'focal_length': .08,
                   'defocus': 700e-6, 'shape': (n, n)}
# Geometry only. Kept under its own name so the simulation's `data_info` is not overwritten
init_data_info = {'wavelength': wavelength, 'detector_distance': detector_distance,
                  'detector_pixel_size': pixel_size_detector}
init = simulation.Simulation(obj_info=obj_init_info, probe_info=probe_init_info, data_info=init_data_info)

init.make_obj()
init.make_probe()
Simulating object: random
Simulating probe: focus
WARNING: exceeding maximum near field propagation distance: z=0.000700 > 0.000194

Alter the positions & create the Ptycho object

Only two positions (indices 10 and 20) are altered here, but more can be added.

The `p._interpolation` parameter can be used to enable interpolation: when a scan position does not correspond to an integer number of pixels, the object is interpolated with a bilinear approximation. Note that this interpolation is not necessary to determine the scan positions.

[4]:
# Perturb two scan positions (in object pixels) so the position refinement has errors to correct
posx1, posy1 = posx.copy(), posy.copy()
posx1[10] += 5
posy1[10] += 10
posx1[20] -= 5
posy1[20] -= 5

# Optionally snap the starting positions to integer pixels
round_positions = False
if round_positions:
    posx1, posy1 = np.round(posx1), np.round(posy1)

# Positions are converted from object pixels to metres. The geometry reuses the simulation parameters
# instead of repeating the literal values.
data = PtychoData(iobs=ampl ** 2, positions=(posx1 * pixel_size_object, posy1 * pixel_size_object),
                  detector_distance=detector_distance, mask=None, pixel_size_detector=pixel_size_detector,
                  wavelength=wavelength)

# Start from the random object AND the initial probe built in the previous cell. The original code
# passed s.probe.values, the exact probe used to simulate the data, which contradicts the stated setup.
p = Ptycho(probe=init.probe.values, obj=init.obj.values, data=data, background=None)

# Use bilinear interpolation of the object for non-integer scan positions?
# (NOTE(review): private attribute; may change between PyNX versions)
p._interpolation = False

# Initial scaling of object and probe
p = ScaleObjProbe(verbose=True) * p
ScaleObjProbe: 4342.861 295804.06 22119.73934359367 1227.4508294961252 18.020875979504595

Initial object and probe optimisation

[5]:
plt.figure()
# All three stages evaluate the log-likelihood and display object/probe every 20 cycles
monitor = dict(calc_llk=20, show_obj_probe=20)
# 40 cycles each: DM (object + probe), then AP (object only), then ML (object + probe)
p = DM(update_object=True, update_probe=True, **monitor) ** 40 * p
p = AP(update_object=True, update_probe=False, **monitor) ** 40 * p
p = ML(update_object=True, update_probe=True, **monitor) ** 40 * p
DM/o/p     #  0 LLK= 292753.77(p) 565050436.27(g) 464747.43(e), nb photons=2.371754e+13, dt/cycle=0.117s
DM/o/p     # 20 LLK= 17818.96(p) 2186092.27(g) 27596.44(e), nb photons=2.475301e+13, dt/cycle=0.026s
DM/o/p     # 39 LLK= 23930.98(p) 2355334.40(g) 36488.55(e), nb photons=2.454958e+13, dt/cycle=0.026s
AP/o       # 40 LLK= 23931.00(p) 2355332.53(g) 36488.58(e), nb photons=2.454957e+13, dt/cycle=0.218s
AP/o       # 60 LLK=  7552.97(p) 4138536.53(g) 14960.14(e), nb photons=2.491395e+13, dt/cycle=0.018s
AP/o       # 79 LLK=  6991.36(p) 3808269.07(g) 13995.02(e), nb photons=2.491575e+13, dt/cycle=0.019s
ML/o/p     # 81 LLK=  6590.69(p) 3506709.87(g) 13189.20(e), nb photons=2.491670e+13, dt/cycle=0.104s
ML/o/p     #101 LLK=  6114.22(p) 4837075.20(g) 12981.71(e), nb photons=2.494196e+13, dt/cycle=0.037s
ML/o/p     #120 LLK=  6022.78(p) 4885213.33(g) 12855.35(e), nb photons=2.494200e+13, dt/cycle=0.038s

Optimise positions

This works best with the AP or ML algorithms; DM tends to be less stable.

We use the `pos_history` option so that the position history can be plotted against the cycle number later. This slows down the optimisation, because the data must be retrieved from the GPU at every cycle.

[6]:
plt.figure()  # the position refinement is displayed in its own figure
# DM alternative (tends to be less stable for position refinement):
# p = ShowObjProbe() * DM(update_object=True, update_probe=True, update_pos=True, pos_threshold=0.1,
#                         pos_min_shift=0.0, pos_max_shift=2, pos_history=True, calc_llk=20,
#                         show_obj_probe=20)**100 * p

# AP stage: object, probe and position updates, with the position history recorded for later plots
ap_pos_options = dict(update_pos=5, pos_mult=5, pos_threshold=0.2, pos_min_shift=0.0, pos_max_shift=2)
p = ShowObjProbe() * AP(update_object=True, update_probe=True, pos_history=True,
                        calc_llk=50, show_obj_probe=50, **ap_pos_options) ** 500 * p
# ML stage: 100 more cycles, still updating positions and recording their history
p = ShowObjProbe() * ML(update_object=True, update_probe=True, update_pos=5,
                        pos_history=True, calc_llk=20, show_obj_probe=20) ** 100 * p

AP/o/p/t   #121 LLK=  6022.65(p) 4885588.80(g) 12855.25(e), nb photons=2.494203e+13, dt/cycle=0.279s
AP/o/p/t   #171 LLK=  2904.48(p) 2463835.20(g)  6146.42(e), nb photons=2.493533e+13, dt/cycle=0.022s
AP/o/p/t   #221 LLK=  2546.17(p) 2125552.27(g)  5465.35(e), nb photons=2.493574e+13, dt/cycle=0.021s
AP/o/p/t   #271 LLK=  2644.06(p) 2094336.80(g)  5672.60(e), nb photons=2.493343e+13, dt/cycle=0.021s
AP/o/p/t   #321 LLK=  2325.79(p) 1566365.47(g)  5048.89(e), nb photons=2.493411e+13, dt/cycle=0.021s
AP/o/p/t   #371 LLK=  2180.36(p) 1535412.40(g)  4762.99(e), nb photons=2.493360e+13, dt/cycle=0.022s
AP/o/p/t   #421 LLK=  2485.80(p) 1538936.27(g)  5374.69(e), nb photons=2.492999e+13, dt/cycle=0.021s
AP/o/p/t   #471 LLK=  2315.50(p) 1528121.60(g)  5037.12(e), nb photons=2.493119e+13, dt/cycle=0.022s
AP/o/p/t   #521 LLK=  2461.21(p) 1530101.60(g)  5332.94(e), nb photons=2.493124e+13, dt/cycle=0.022s
AP/o/p/t   #571 LLK=  2343.66(p) 1522273.47(g)  5098.71(e), nb photons=2.493093e+13, dt/cycle=0.021s
AP/o/p/t   #620 LLK=  2324.34(p) 1505121.07(g)  5061.07(e), nb photons=2.493170e+13, dt/cycle=0.022s
ML/o/p/t   #622 LLK=  2118.93(p) 1466608.40(g)  4653.07(e), nb photons=2.493149e+13, dt/cycle=0.173s
ML/o/p/t   #642 LLK=  2290.70(p) 2042786.53(g)  5269.61(e), nb photons=2.493727e+13, dt/cycle=0.044s
ML/o/p/t   #662 LLK=  2525.50(p) 2170694.93(g)  5786.08(e), nb photons=2.493879e+13, dt/cycle=0.044s
ML/o/p/t   #682 LLK=  2628.20(p) 2234643.20(g)  6013.04(e), nb photons=2.494799e+13, dt/cycle=0.046s
ML/o/p/t   #702 LLK=  2782.79(p) 3060746.93(g)  6351.56(e), nb photons=2.495193e+13, dt/cycle=0.044s
ML/o/p/t   #721 LLK=  2537.53(p) 2239914.13(g)  5856.26(e), nb photons=2.494376e+13, dt/cycle=0.045s

Plot the position shifts

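The recorded position history can be plotted manually. The left panel shows the trajectories of the first 50 positions. The right panels show x and y versus cycle for the two altered positions (10 and 20).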

[7]:
# Scan positions whose x/y history vs cycle is plotted on the right:
# 10 and 20 are the two positions that were deliberately shifted before the reconstruction
ipos = [10, 20]
# Number of scan positions whose trajectories are drawn in the left panel
nb_pos_plot = 50
fig = plt.figure(figsize=(9.5, 4))

# Left panel: (x, y) trajectory of each position during the optimisation,
# labelled with the position index at its first recorded point.
# Each history entry v is read as (cycle, x, y).
ax = plt.subplot(121)
for i in range(nb_pos_plot):
    hist = p.position_history[i]
    x = [v[1] for v in hist]
    y = [v[2] for v in hist]
    ax.scatter(x, y, 1)
    ax.text(x[0], y[0], '%d' % i)
ax.set_aspect(1)
ax.set_xlabel('x')
ax.set_ylabel('y')

# Right panels: x (blue, left axis) and y (red, twin right axis) vs cycle for each selected position
for k, ip in enumerate(ipos):
    ax_x = plt.subplot(2, 2, 2 + 2 * k)
    hist = p.position_history[ip]
    ix = [v[0] for v in hist]
    ax_x.plot(ix, [v[1] for v in hist], 'b.', label='x[%d]' % ip)
    ax_x.set_xlabel('cycle')
    ax_y = ax_x.twinx()
    ax_y.plot(ix, [v[2] for v in hist], 'r.', label='y[%d]' % ip)

# A single figure legend collecting all labelled lines. Calling fig.legend() inside the loop
# drew a partial legend under the final one at the same location.
fig.legend(loc="center right")
plt.tight_layout()
[ ]: