Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions scopesim/effects/apertures.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def apply_to(self, obj, **kwargs):
for vol in obj.volumes:
vol["meta"]["xi_min"] = min(x) * u.arcsec
vol["meta"]["xi_max"] = max(x) * u.arcsec
vol["meta"]["slit_name"] = self.meta["name"]

return obj

Expand Down
34 changes: 20 additions & 14 deletions scopesim/effects/selector_wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,12 @@ class SelectorWheel(Effect):
"""

z_order = (290, 690, 890)
z_order = ()
required_keys = {"selector_key", "wheel"}

def __init__(self, **kwargs):
super().__init__(**kwargs)
check_keys(kwargs, self.required_keys, action="error")

self.meta.update(kwargs)
super().__init__(**kwargs)

self.wheel_effects = {}
for wheel_entry in self.meta["wheel"]:
Expand All @@ -70,9 +68,12 @@ def __init__(self, **kwargs):
# Instantiate the effect and store it in the wheel_effects dictionary
if isinstance(selector_value, list):
for val in selector_value:
self.wheel_effects[val] = effect_class(**effect_kwargs)
self.wheel_effects[val] = effect_class(cmds=self.cmds, **effect_kwargs)
else:
self.wheel_effects[selector_value] = effect_class(**effect_kwargs)
self.wheel_effects[selector_value] = effect_class(cmds=self.cmds, **effect_kwargs)

# Use the wheel effects' z_order as the z_order of the selector wheel
self.z_order = [eff.z_order for eff in self.wheel_effects.values()][0]


def apply_to(self, obj, **kwargs):
Expand All @@ -95,15 +96,19 @@ def apply_to(self, obj, **kwargs):

for val in unique_selector_values:
vols_with_val = [vol for vol in obj.volumes if vol["meta"].get(self.meta["selector_key"], None) == val]
effect_to_apply = self.get_effect(val)
logger.info(f"Applying effect for selector value: {val} -> {effect_to_apply}, volumes: {len(vols_with_val)}")

if val is None or effect_to_apply is None:
if val is None:
logger.warning(f"Volume(s) with missing selector key {self.meta['selector_key']} value found, "
f"applying no effect to those volumes.")
new_volumes.extend(vols_with_val)
continue

#TODO: If effect_to_apply is a dichroic which reassigns selector_key values, we need to add a check here
#TODO: i.e. dichroic.arm_action.keys() should not include items in unique_selector_values other than val
effect_to_apply = self.get_effect(val)
logger.debug(f"Applying effect for {self.meta['selector_key']}: {val} -> {effect_to_apply}, volumes: {len(vols_with_val)}")

if effect_to_apply is None:
new_volumes.extend(vols_with_val)
continue

newvollist = FovVolumeList()
newvollist.volumes = vols_with_val
Expand All @@ -113,7 +118,7 @@ def apply_to(self, obj, **kwargs):
obj.volumes = new_volumes

if isinstance(obj, Detector):
logger.info("Since passed object is a Detector, selector_key by default is the ID of the Detector object.")
logger.debug("Since passed object is a Detector, selector_key by default is the ID of the Detector object.")
selector_value = obj.meta[real_colname("id", obj.meta)] # Assuming detector ID is the selector

effect_to_apply = self.get_effect(selector_value)
Expand All @@ -128,8 +133,9 @@ def apply_to(self, obj, **kwargs):

def get_effect(self, selector_value):
eff = None
if (selector_value is None) or (selector_value not in self.wheel_effects.keys()):
logger.warning(f"Either None or Missing value of {self.meta["selector_key"]} requested: {selector_value}, returning None.")
if selector_value not in self.wheel_effects.keys():
logger.warning(f"Entry for selector value {selector_value} not found in wheel effects. "
f"Assuming no effect to apply for this selector value.")
else:
eff = self.wheel_effects[selector_value]
return eff
Expand Down
90 changes: 88 additions & 2 deletions scopesim/effects/spectral_efficiency.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""Spectral grating efficiencies."""

from typing import ClassVar
from typing import ClassVar, Callable

import numpy as np
from astropy.io import fits
Expand All @@ -12,7 +12,9 @@
from .effects import Effect
from .ter_curves import TERCurve
from .ter_curves_utils import apply_throughput_to_cube
from ..utils import figure_factory, get_logger
from ..utils import figure_factory, get_logger, check_keys
from .data_container import DataContainer
from ..optics import echelle


logger = get_logger(__name__)
Expand Down Expand Up @@ -127,3 +129,87 @@ def plot(self):
axes.legend()

return fig


class EchelleSpectralEfficiency(Effect):
"""
Spectral efficiency list from analytical calculations of the blaze function for ZShooter gratings.
Requires same input trace parameter table as EchelleSpectralTraceList, supply as kwarg "filename"
"""
z_order: ClassVar[tuple[int, ...]] = (630,)

def __init__(self, **kwargs):
super().__init__(**kwargs)
self._spectrographs = None
self.efficiencies = {}
self.efficiency_generator = self._generate_efficiency_curve_func()

def _generate_efficiency_curve_func(self) -> Callable:
trace_params = self.table
spectrographs = {}
for row in trace_params:
prefix = row["prefix"] # note trance ids are assumed to be prefix_{order}
min_order = row['m0'] - row['n']
max_order = row['m0']
min_wave = row['min_wave'] * u.Unit(trace_params.meta["min_wave_unit"])
max_wave = row['max_wave'] * u.Unit(trace_params.meta["max_wave_unit"])
design_res = row['design_res']
focal_len = row['focal_length'] * u.Unit(trace_params.meta["focal_length_unit"])
disp_npix = row['n_disp'] - 2 * row['detector_pad']
xdisp_npix = row['n_xdisp']- 2 * row['detector_pad']
pix_size = row['pixel_size'] * u.Unit(trace_params.meta["pixel_size_unit"])
echelle_angle = np.deg2rad(row['echelle_blaze'])*u.rad
xdisp_beta_center = np.deg2rad(row['xbeta_center'])*u.rad

xdisp_groove_length = u.Unit(trace_params.meta["xdisp_freq_unit"]) / row['xdisp_freq']
echelle_groove_length = u.Unit(trace_params.meta["disp_freq_unit"]) / row['disp_freq']
pix_per_res_elem = row['fwhm']

spectrographs[prefix] = echelle.spectrograph_factory(min_wave, max_wave, focal_len,
design_res, echelle_angle, min_order, max_order,
echelle_groove_length, pix_per_res_elem, disp_npix, xdisp_npix,
pix_size, xdisp_groove_length=xdisp_groove_length,
xdisp_beta_center=xdisp_beta_center)
self._spectrographs = spectrographs

def efficiency_curve(trace_id, wavelength):
"""Trace ID MUST be in the form prefix_{order}"""
prefix, _, order = trace_id.partition('_')
order = int(order)
spec = spectrographs[prefix]
blaze = spec.grating.blaze(spec.grating.beta(wavelength, order), order)
xdisp = spec.xdisp_efficiency(wavelength)
return blaze*xdisp

return efficiency_curve

def apply_to(self, obj, **kwargs):
"""Interface between FieldOfView and SpectralEfficiency."""
trace_id = obj.trace_id

swcs = WCS(obj.hdu.header).spectral
with u.set_enabled_equivalencies(u.spectral()):
wave = swcs.pixel_to_world(np.arange(swcs.pixel_shape[0])) << u.um

efficiency = self.efficiency_generator(trace_id, wave)
params = {"description": trace_id}
params.update(self.meta)
effic_curve = TERCurve(array_dict={"wavelength": wave, "transmission": efficiency}, **params)
self.efficiencies[trace_id] = effic_curve

obj.hdu = apply_throughput_to_cube(obj.hdu, effic_curve.throughput, wave)
return obj

def plot(self):
"""Plot the grating efficiencies."""
fig, axes = figure_factory()
for name, effic in self.efficiencies.items():
wave = effic.throughput.waveset
axes.plot(wave.to(u.um), effic.throughput(wave), label=name)

axes.set_xlabel("Wavelength [um]")
axes.set_ylabel("Grating efficiency")
axes.set_title(f"Grating efficiencies {self.display_name}")
axes.legend()

return fig
88 changes: 44 additions & 44 deletions scopesim/effects/spectral_trace_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
The Effect is called `SpectralTraceList`, it applies a list of
`spectral_trace_list_utils.SpectralTrace` objects to a `FieldOfView`.
"""

from itertools import cycle
from typing import ClassVar

Expand All @@ -15,14 +14,15 @@
from astropy.io import fits
from astropy.table import Table
import astropy.units as u
import os

from .effects import Effect
from .ter_curves import FilterCurve
from .spectral_trace_list_utils import SpectralTrace, make_image_interpolations
from ..optics.image_plane_utils import header_from_list_of_xy
from ..optics.fov import FieldOfView
from ..optics.fov_volume_list import FovVolumeList
from ..utils import from_currsys, check_keys, figure_factory, get_logger
from ..utils import from_currsys, check_keys, figure_factory, get_logger, from_rc_config
from .data_container import DataContainer
from ..optics import echelle

Expand Down Expand Up @@ -108,8 +108,8 @@ class SpectralTraceList(Effect):
report_plot_include: ClassVar[bool] = True
report_table_include: ClassVar[bool] = False

def __init__(self, **kwargs):
super().__init__(**kwargs)
def __init__(self, cmds=None, **kwargs):
super().__init__(cmds=cmds, **kwargs)

if "hdulist" in kwargs and isinstance(kwargs["hdulist"], fits.HDUList):
self._file = kwargs["hdulist"]
Expand Down Expand Up @@ -137,7 +137,6 @@ def __init__(self, **kwargs):

if self._file is not None:
self.make_spectral_traces()

self.update_meta()

def make_spectral_traces(self):
Expand Down Expand Up @@ -524,7 +523,7 @@ class EchelleSpectralTraceList(SpectralTraceList):
instead of loading them from FITS file. The arguments required to define the echelle traces are supplied through
a txt file containing a table of parameters using the filename kwarg.
Below is an example of how to define the echelle trace parameters (see irdb/ZShooter/traces/echelle_trace_parameters.txt):
Below is an example of how to define the echelle trace parameters (see irdb/ZShooter_v1/traces/echelle_trace_parameters.txt):
----------------------------------------------------------------
# min_wave_unit : nm
# max_wave_unit : nm
Expand All @@ -539,10 +538,10 @@ class EchelleSpectralTraceList(SpectralTraceList):
# xdisp_freq_unit : mm
# slitwidth_unit : arcsec
prefix aperture_id image_plane_id m0 n min_wave max_wave echelle_blaze focal_length fwhm detector_pad pixel_size n_disp n_xdisp disp_freq xdisp_freq slitwidth dispdir plate_scale
nIR 0 0 40 24 970 2500 64.2 225 4.7 10 0.015 4096 4096 45 175 10 x 0.159574468085
gri 1 1 36 18 490 1020 64.2 225 4.7 10 0.015 4096 4096 100 500 10 x 0.159574468085
ub 2 2 29 11 315 515 64.2 225 4.7 10 0.015 4096 4096 200 1000 10 x 0.159574468085
prefix aperture_id image_plane_id m0 n min_wave max_wave design_res echelle_blaze focal_length fwhm detector_pad pixel_size n_disp n_xdisp disp_freq xdisp_freq slitwidth dispdir
ub 0 2 29 11 315 515 20000 64.2 225 4.7 10 0.015 4096 4096 200 1000 10 x
gri 1 1 36 18 490 1020 20000 64.2 225 4.7 10 0.015 4096 4096 100 500 10 x
nIR 2 0 40 24 970 2500 20000 64.2 225 4.7 10 0.015 4096 4096 45 175 10 x
----------------------------------------------------------------
The calculated traces are stored in the same HDUList format as required by SpectralTraceList,
Expand All @@ -552,13 +551,19 @@ class EchelleSpectralTraceList(SpectralTraceList):
required_keys = {"filename"}
z_order = (71, 271, 671)

def __init__(self, **kwargs):
def __init__(self, cmds=None, **kwargs):
check_keys(kwargs, self.required_keys, action="error")
self.cmds = cmds

trace_params = DataContainer(filename=kwargs['filename'])
trace_param_filename = kwargs.pop("filename")
trace_params = DataContainer(filename=trace_param_filename)
hdulist = self._generate_trace_hdulist(trace_params)
hdulist.writeto(f"{from_rc_config('!SIM.file.local_packages_path')}/"
f"{from_currsys('!OBS.instrument', self.cmds)}/"
f"{os.path.dirname(trace_param_filename)}/"
f"analytical_echelle_traces.fits", overwrite=True)
kwargs["hdulist"] = hdulist
super().__init__(**kwargs)
super().__init__(cmds=cmds, **kwargs)

def _generate_trace_hdulist(self, trace_params):
hdul = fits.HDUList()
Expand Down Expand Up @@ -588,54 +593,49 @@ def _generate_trace_hdulist(self, trace_params):
max_order = row['m0']
min_wave = row['min_wave'] * u.Unit(trace_params.meta["min_wave_unit"])
max_wave = row['max_wave'] * u.Unit(trace_params.meta["max_wave_unit"])
design_res = row['design_res']
focal_len = row['focal_length'] * u.Unit(trace_params.meta["focal_length_unit"])
xdisp_npix = row['n_xdisp']
disp_npix = row['n_disp'] - 2 * row['detector_pad']
xdisp_npix = row['n_xdisp'] - 2 * row['detector_pad']
pix_size = row['pixel_size'] * u.Unit(trace_params.meta["pixel_size_unit"])
x_disp_len = (xdisp_npix - 2 * row['detector_pad']) * pix_size
echelle_angle = np.deg2rad(row['echelle_blaze'])
alpha = np.deg2rad(row['alpha'])
beta_center = np.deg2rad(row['beta_center'])
# cross_disperser = echelle.GratingSetup(
# groove_length=u.Unit(trace_params.meta["xdisp_freq_unit"]) / row['xdisp_freq'],
# guess_littrow=(min_wave, max_wave,
# x_disp_len, focal_len))
cross_disperser = echelle.GratingSetup(alpha=alpha, beta_center=beta_center,
delta=beta_center,
groove_length=u.Unit(trace_params.meta["xdisp_freq_unit"]) / row['xdisp_freq'])

ss = echelle.SpectrographSetup((min_order, max_order),
max_wave,
row['fwhm'] * u.Unit(trace_params.meta["fwhm_unit"]),
focal_len,
echelle.GratingSetup(alpha=echelle_angle, beta_center=echelle_angle,
delta=echelle_angle,
groove_length=u.Unit(trace_params.meta["disp_freq_unit"]) / row['disp_freq']),
echelle.Detector(row['n_disp'], xdisp_npix, pix_size),
cross_disperser=cross_disperser
)

fsr_edges = ss.edge_wave(fsr=True)

slit_edge = (row['slitwidth'] / 2) * u.Unit(trace_params.meta["slitwidth_unit"])
echelle_angle = np.deg2rad(row['echelle_blaze'])*u.rad
xdisp_beta_center = np.deg2rad(row['xbeta_center'])*u.rad

xdisp_groove_length = u.Unit(trace_params.meta["xdisp_freq_unit"]) / row['xdisp_freq']
echelle_groove_length = u.Unit(trace_params.meta["disp_freq_unit"]) / row['disp_freq']
pix_per_res_elem = row['fwhm']

ss = echelle.spectrograph_factory(min_wave, max_wave, focal_len,
design_res, echelle_angle, min_order, max_order,
echelle_groove_length, pix_per_res_elem, disp_npix, xdisp_npix,
pix_size, xdisp_groove_length=xdisp_groove_length,
xdisp_beta_center=xdisp_beta_center)

edges = ss.edge_wave(fsr=False)

slit_edge = (row['slitlength'] / 2) * u.Unit(trace_params.meta["slitlength_unit"])
slit_pos = np.linspace(-slit_edge, slit_edge, num=3)
slit_offset_pix = slit_pos / (row['plate_scale'] * u.arcsec)
slit_offset_pix = slit_pos / (from_currsys('!INST.pixel_scale', self.cmds) * u.arcsec)

xvals, yvals = [], []
for i, order in enumerate(ss.orders):
wave = fsr_edges[i]
wave = edges[i]
wave = np.linspace(wave[0], wave[-1], num=max(int(disp_npix*.1), 2))
x = ss.wavelength_to_x_pixel(wave, order)
y = ss.wavelength_to_y_pixel(wave)
pix_y = y + row['detector_pad'] + slit_offset_pix[:, None]
pix_y = y + slit_offset_pix[:, None] + row['detector_pad']
xval = np.tile(x, slit_offset_pix.size)*pix_size.to('mm')
yval = pix_y.ravel()*pix_size.to('mm')
xvals.append(xval)
yvals.append(yval)

# echelle above has 0,0 at detector corner, Scopesim uses 0,0 at detector center
xcent = (np.min(xvals) + (np.max(xvals) - np.min(xvals))/2) * u.mm
ycent = (np.min(yvals) + (np.max(yvals) - np.min(yvals))/2) * u.mm

for i, order in enumerate(ss.orders):
wave = fsr_edges[i]
wave = edges[i]
wave = np.linspace(wave[0], wave[-1], num=max(int(disp_npix*.1), 2))
s = np.tile(slit_pos, wave.size).reshape(wave.size, slit_pos.size).T.ravel()
w = np.tile(wave, slit_offset_pix.size)
xval = xvals[i] - xcent # Centering on 0,0 at detector center
Expand Down
Loading