Source code for lasif.components.adjoint_sources

#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import

import pyasdf
import obspy
import os
import numpy as np
from lasif.utils import process_two_files_without_parallel_output
import toml

from lasif.exceptions import LASIFNotFoundError, LASIFError
from .component import Component
from lasif.tools.adjoint.adjoint_source import calculate_adjoint_source
from lasif.utils import select_component_from_stream

# Module-level cache for the TauPy velocity model. It starts out empty and is
# filled lazily under the key "model" the first time validation misfits are
# computed, so the model is only constructed once per process.
TAUPY_MODEL_CACHE = {}


class AdjointSourcesComponent(Component):
    """
    Component dealing with the adjoint sources.

    :param folder: The folder where the files are stored.
    :param communicator: The communicator instance.
    :param component_name: The name of this component for the communicator.
    """

    def __init__(self, folder, communicator, component_name):
        # Remember the storage root first, then let the base class register
        # this component with the communicator.
        self._folder = folder
        super().__init__(communicator, component_name)

[docs] def get_filename(self, event: str, iteration: str): """ Gets the filename for the adjoint source. :param event: The event. :type event: str :param iteration: The iteration name. :type iteration: str """ iteration_long_name = self.comm.iterations.get_long_iteration_name( iteration ) folder = os.path.join(self._folder, iteration_long_name, event) if not os.path.exists(folder): os.makedirs(folder) return os.path.join(folder, "adjoint_source_auxiliary.h5")
[docs] def get_misfit_file(self, iteration: str): """ Get path to the iteration misfit file :param iteration: Name of iteration :type iteration: str """ iteration_name = self.comm.iterations.get_long_iteration_name( iteration ) file = ( self.comm.project.paths["iterations"] / iteration_name / "misfits.toml" ) if not os.path.exists(file): raise LASIFNotFoundError(f"File {file} does not exist") return file
[docs] def get_misfit_for_event( self, event: str, iteration: str, weight_set_name: str = None, include_station_misfit: bool = False, ): """ This function returns the total misfit for an event. :param event: name of the event :type event: str :param iteration: iteration for which to get the misfit :type iteration: str :param weight_set_name: Name of station weights, defaults to None :type weight_set_name: str, optional :param include_station_misfit: Whether individual station misfits should be written down or not, defaults to False :type include_station_misfit: bool, optional """ misfit_file = self.get_misfit_file(iteration) misfits = toml.load(misfit_file) if event not in misfits.keys(): raise LASIFError( f"Misfit has not been computed for event {event}, " f"iteration: {iteration}. " ) event_misfit = misfits[event]["event_misfit"] if include_station_misfit: return misfits[event] else: return event_misfit
[docs] def calculate_validation_misfits(self, event: str, iteration: str): """ This fuction computed the L2 weighted waveform misfit over a whole trace. It is meant to provide misfits for validation purposes. E.g. to steer regularization parameters. :param event: name of the event :type event: str :param iteration: iteration for which to get the misfit :type iteration: str """ from scipy.integrate import simps from obspy import geodetics from lasif.utils import progress min_sn_ratio = 0.05 event = self.comm.events.get(event) # Fill cache if necessary. if not TAUPY_MODEL_CACHE: from obspy.taup import TauPyModel # NOQA TAUPY_MODEL_CACHE["model"] = TauPyModel("AK135") model = TAUPY_MODEL_CACHE["model"] # Get the ASDF filenames. processed_filename = self.comm.waveforms.get_asdf_filename( event_name=event["event_name"], data_type="processed", tag_or_iteration=self.comm.waveforms.preprocessing_tag, ) synthetic_filename = self.comm.waveforms.get_asdf_filename( event_name=event["event_name"], data_type="synthetic", tag_or_iteration=iteration, ) dt = self.comm.project.simulation_settings["time_step_in_s"] ds_syn = pyasdf.ASDFDataSet(synthetic_filename, mode="r", mpi=False) ds_obs = pyasdf.ASDFDataSet(processed_filename, mode="r", mpi=False) event_latitude = event["latitude"] event_longitude = event["longitude"] event_depth_in_km = event["depth_in_km"] minimum_period = self.comm.project.simulation_settings[ "minimum_period_in_s"] misfit = 0.0 for i, station in enumerate(ds_obs.waveforms.list()): if i % 30 == 0: progress(i+1, len(ds_obs.waveforms.list()), status="Computing misfits") observed_station = ds_obs.waveforms[station] synthetic_station = ds_syn.waveforms[station] obs_tag = observed_station.get_waveform_tags() syn_tag = synthetic_station.get_waveform_tags() try: # Make sure both have length 1. assert len(obs_tag) == 1, ( "Station: %s - Requires 1 observed waveform tag." " Has %i." 
% (observed_station._station_name, len(obs_tag)) ) except AssertionError: continue assert len(syn_tag) == 1, ( "Station: %s - Requires 1 synthetic waveform tag. " "Has %i." % (observed_station._station_name, len(syn_tag)) ) obs_tag = obs_tag[0] syn_tag = syn_tag[0] station_latitude = observed_station.coordinates["latitude"] station_longitude = observed_station.coordinates["longitude"] st_obs = observed_station[obs_tag] st_syn = synthetic_station[syn_tag] # Sample points down to 10 points per minimum_period # len_s = st_obs[0].stats.endtime - st_obs[0].stats.starttime # num_samples_wavelength = 10.0 # new_sampling_rate = num_samples_wavelength * \ # minimum_period / len_s # st_obs = st_obs.resample(new_sampling_rate) # st_syn = st_syn.resample(new_sampling_rate) # dt = 1.0/new_sampling_rate dist_in_deg = geodetics.locations2degrees( station_latitude, station_longitude, event_latitude, event_longitude ) # Get only a couple of P phases which should be the # first arrival # for every epicentral distance. Its quite a bit faster # than calculating # the arrival times for every phase. # Assumes the first sample is the centroid time of the event. ttp = model.get_travel_times( source_depth_in_km=event_depth_in_km, distance_in_degree=dist_in_deg, phase_list=["ttp"], ) # Sort just as a safety measure. ttp = sorted(ttp, key=lambda x: x.time) first_tt_arrival = ttp[0].time # Estimate noise level from waveforms prior to the # first arrival. 
idx_end = int(np.ceil((first_tt_arrival - 0.5 * minimum_period) / dt)) idx_end = max(10, idx_end) idx_start = int( np.ceil((first_tt_arrival - 2.5 * minimum_period) / dt)) idx_start = max(10, idx_start) if idx_start >= idx_end: idx_start = max(0, idx_end - 10) for component in ["E", "N", "Z"]: try: data_tr = select_component_from_stream(st_obs, component) synth_tr = select_component_from_stream(st_syn, component) except LASIFNotFoundError: continue # Scale data to synthetics scaling_factor = (synth_tr.data.ptp() / data_tr.data.ptp()) if np.isinf(scaling_factor): continue # Store and apply the scaling. data_tr.stats.scaling_factor = scaling_factor data_tr.data *= scaling_factor data = data_tr.data abs_data = np.abs(data) noise_absolute = abs_data[idx_start:idx_end].max() noise_relative = noise_absolute / abs_data.max() if noise_relative > min_sn_ratio: continue # normalize the trace to [-1,1], reduce source effects # and balance amplitudes norm_scaling_fac = 1.0 / np.max(np.abs(synth_tr.data)) data_tr.data *= norm_scaling_fac synth_tr.data *= norm_scaling_fac # envelope = obspy.signal.filter.envelope(data_tr.data) # scale up to around 1, also never divide by 0 # by adding regularization term, dependent on noise level # env_weighting = 1.0 / ( # envelope + np.max(envelope) * 0.2) # data_tr.data *= env_weighting # synth_tr.data *= env_weighting diff = data_tr.data - synth_tr.data misfit += 0.5 * simps(y=diff ** 2, dx=data_tr.stats.delta) print("\nTotal event misfit: ", misfit) return misfit
[docs] def calculate_adjoint_sources_multiprocessing( self, event: str, iteration: str, window_set_name: str, num_processes: int, plot: bool = False, **kwargs, ): """ Calculate adjoint sources based on the type of misfit defined in the lasif config file. The computed misfit for each station is also written down into a misfit toml file. This function uses multiprocessing for parallelization :param event: Name of event :type event: str :param iteration: Name of iteration :type iteration: str :param window_set_name: Name of window set :type window_set_name: str :param num_processes: The number of processes used in multiprocessing :type num_processes: int :param plot: Should the adjoint source be plotted?, defaults to False :type plot: bool, optional """ from lasif.utils import select_component_from_stream from tqdm import tqdm import multiprocessing import warnings warnings.filterwarnings("ignore") # Globally define the processing function. This is required to enable # pickling of a function within a function. Alternatively, a solution # can be found that does not utilize a function within a function. global _process event = self.comm.events.get(event) # Get the ASDF filenames. processed_filename = self.comm.waveforms.get_asdf_filename( event_name=event["event_name"], data_type="processed", tag_or_iteration=self.comm.waveforms.preprocessing_tag, ) synthetic_filename = self.comm.waveforms.get_asdf_filename( event_name=event["event_name"], data_type="synthetic", tag_or_iteration=iteration, ) if not os.path.exists(processed_filename): msg = "File '%s' does not exists." % processed_filename raise LASIFNotFoundError(msg) if not os.path.exists(synthetic_filename): msg = "File '%s' does not exists." 
% synthetic_filename raise LASIFNotFoundError(msg) all_windows = self.comm.windows.read_all_windows( event=event["event_name"], window_set_name=window_set_name ) process_params = self.comm.project.simulation_settings def _process(station): ds = pyasdf.ASDFDataSet(processed_filename, mode="r", mpi=False) ds_synth = pyasdf.ASDFDataSet(synthetic_filename, mode="r", mpi=False) observed_station = ds.waveforms[station] synthetic_station = ds_synth.waveforms[station] # print(observed_station, synthetic_station) obs_tag = observed_station.get_waveform_tags() syn_tag = synthetic_station.get_waveform_tags() adjoint_sources = {} try: # Make sure both have length 1. assert len(obs_tag) == 1, ( "Station: %s - Requires 1 observed waveform tag. Has %i." % (observed_station._station_name, len(obs_tag)) ) assert len(syn_tag) == 1, ( "Station: %s - Requires 1 synthetic waveform tag. Has %i." % (observed_station._station_name, len(syn_tag)) ) except AssertionError: return {station: adjoint_sources} obs_tag = obs_tag[0] syn_tag = syn_tag[0] # Finally get the data. st_obs = observed_station[obs_tag] st_syn = synthetic_station[syn_tag] # Process the synthetics. st_syn = self.comm.waveforms.process_synthetics( st=st_syn.copy(), event_name=event["event_name"], iteration=iteration, ) ad_src_type = self.comm.project.optimization_settings[ "misfit_type" ] if ad_src_type == "weighted_waveform_misfit": env_scaling = True ad_src_type = "waveform_misfit" else: env_scaling = False for component in ["E", "N", "Z"]: try: data_tr = select_component_from_stream(st_obs, component) synth_tr = select_component_from_stream(st_syn, component) except LASIFNotFoundError: continue if self.comm.project.simulation_settings[ "scale_data_to_synthetics"]: if (not self.comm.project.optimization_settings[ "misfit_type"] == "L2NormWeighted"): scaling_factor = ( synth_tr.data.ptp() / data_tr.data.ptp() ) # Store and apply the scaling. 
data_tr.stats.scaling_factor = scaling_factor data_tr.data *= scaling_factor net, sta, cha = data_tr.id.split(".", 2) station = net + "." + sta if station not in all_windows: continue if data_tr.id not in all_windows[station]: continue # Collect all. windows = all_windows[station][data_tr.id] try: # for window in windows: asrc = calculate_adjoint_source( observed=data_tr, synthetic=synth_tr, window=windows, min_period=process_params["minimum_period_in_s"], max_period=process_params["maximum_period_in_s"], adj_src_type=ad_src_type, window_set=window_set_name, taper_ratio=0.15, taper_type="cosine", plot=plot, envelope_scaling=env_scaling, ) except: # Either pass or fail for the whole component. continue if not asrc: continue # Sum up both misfit, and adjoint source. misfit = asrc.misfit adj_source = asrc.adjoint_source.data adjoint_sources[data_tr.id] = { "misfit": misfit, "adj_source": adj_source, } adj_dict = {station: adjoint_sources} return adj_dict ds = pyasdf.ASDFDataSet(processed_filename, mode="r", mpi=False) # Generate task list task_list = ds.waveforms.list() # Use at most num_processes number_processes = min(num_processes, multiprocessing.cpu_count()) with multiprocessing.Pool(number_processes) as pool: results = {} with tqdm(total=len(task_list)) as pbar: for i, r in enumerate(pool.imap_unordered(_process, task_list)): pbar.update() k, v = r.popitem() results[k] = v pool.close() pool.join() # Write adjoint sources filename = self.get_filename( event=event["event_name"], iteration=iteration ) long_iter_name = self.comm.iterations.get_long_iteration_name( iteration ) misfit_toml = self.comm.project.paths["iterations"] toml_filename = misfit_toml / long_iter_name / "misfits.toml" ad_src_counter = 0 if os.path.exists(toml_filename): iteration_misfits = toml.load(toml_filename) if event["event_name"] in iteration_misfits.keys(): iteration_misfits[event["event_name"]][ "event_misfit" ] = 0.0 with open(toml_filename, "w") as fh: toml.dump(iteration_misfits, fh) 
print("Writing adjoint sources...") with pyasdf.ASDFDataSet(filename=filename, mpi=False, mode="a") as bs: if toml_filename.exists(): iteration_misfits = toml.load(toml_filename) if event["event_name"] in iteration_misfits.keys(): total_misfit = iteration_misfits[ event["event_name"] ]["event_misfit"] else: iteration_misfits[event["event_name"]] = {} iteration_misfits[event["event_name"]][ "stations" ] = {} total_misfit = 0.0 else: iteration_misfits = {} iteration_misfits[event["event_name"]] = {} iteration_misfits[event["event_name"]]["stations"] = {} total_misfit = 0.0 for value in results.values(): if not value: continue station_misfit = 0.0 for c_id, adj_source in value.items(): net, sta, loc, cha = c_id.split(".") bs.add_auxiliary_data( data=adj_source["adj_source"], data_type="AdjointSources", path="%s_%s/Channel_%s_%s" % (net, sta, loc, cha), parameters={"misfit": adj_source["misfit"]}, ) station_misfit += adj_source["misfit"] station_name = f"{net}.{sta}" iteration_misfits[event["event_name"]]["stations"][ station_name ] = float(station_misfit) ad_src_counter += 1 total_misfit += station_misfit iteration_misfits[event["event_name"]][ "event_misfit" ] = float(total_misfit) with open(toml_filename, "w") as fh: toml.dump(iteration_misfits, fh) with pyasdf.ASDFDataSet( filename=filename, mpi=False, mode="a" ) as ds: length = len(ds.auxiliary_data.AdjointSources.list()) print(f"{length} Adjoint sources are in your file.")
[docs] def calculate_adjoint_sources( self, event: str, iteration: str, window_set_name: str, plot: bool = False, **kwargs, ): """ Calculate adjoint sources based on the type of misfit defined in the lasif config file. The computed misfit for each station is also written down into a misfit toml file. :param event: Name of event :type event: str :param iteration: Name of iteration :type iteration: str :param window_set_name: Name of window set :type window_set_name: str :param plot: Should the adjoint source be plotted?, defaults to False :type plot: bool, optional """ from lasif.utils import select_component_from_stream from mpi4py import MPI import pyasdf event = self.comm.events.get(event) # Get the ASDF filenames. processed_filename = self.comm.waveforms.get_asdf_filename( event_name=event["event_name"], data_type="processed", tag_or_iteration=self.comm.waveforms.preprocessing_tag, ) synthetic_filename = self.comm.waveforms.get_asdf_filename( event_name=event["event_name"], data_type="synthetic", tag_or_iteration=iteration, ) if not os.path.exists(processed_filename): msg = "File '%s' does not exists." % processed_filename raise LASIFNotFoundError(msg) if not os.path.exists(synthetic_filename): msg = "File '%s' does not exists." % synthetic_filename raise LASIFNotFoundError(msg) # Read all windows on rank 0 and broadcast. if MPI.COMM_WORLD.rank == 0: all_windows = self.comm.windows.read_all_windows( event=event["event_name"], window_set_name=window_set_name ) else: all_windows = {} all_windows = MPI.COMM_WORLD.bcast(all_windows, root=0) process_params = self.comm.project.simulation_settings def process(observed_station, synthetic_station): obs_tag = observed_station.get_waveform_tags() syn_tag = synthetic_station.get_waveform_tags() # Make sure both have length 1. assert len(obs_tag) == 1, ( "Station: %s - Requires 1 observed waveform tag. Has %i." 
% (observed_station._station_name, len(obs_tag)) ) assert len(syn_tag) == 1, ( "Station: %s - Requires 1 synthetic waveform tag. Has %i." % (observed_station._station_name, len(syn_tag)) ) obs_tag = obs_tag[0] syn_tag = syn_tag[0] # Finally get the data. st_obs = observed_station[obs_tag] st_syn = synthetic_station[syn_tag] # Process the synthetics. st_syn = self.comm.waveforms.process_synthetics( st=st_syn.copy(), event_name=event["event_name"], iteration=iteration, ) adjoint_sources = {} ad_src_type = self.comm.project.optimization_settings[ "misfit_type" ] if ad_src_type == "weighted_waveform_misfit": env_scaling = True ad_src_type = "waveform_misfit" else: env_scaling = False for component in ["E", "N", "Z"]: try: data_tr = select_component_from_stream(st_obs, component) synth_tr = select_component_from_stream(st_syn, component) except LASIFNotFoundError: continue if self.comm.project.simulation_settings[ "scale_data_to_synthetics" ]: if ( not self.comm.project.optimization_settings[ "misfit_type" ] == "L2NormWeighted" ): scaling_factor = ( synth_tr.data.ptp() / data_tr.data.ptp() ) # Store and apply the scaling. data_tr.stats.scaling_factor = scaling_factor data_tr.data *= scaling_factor net, sta, cha = data_tr.id.split(".", 2) station = net + "." + sta if station not in all_windows: continue if data_tr.id not in all_windows[station]: continue # Collect all. windows = all_windows[station][data_tr.id] try: # for window in windows: asrc = calculate_adjoint_source( observed=data_tr, synthetic=synth_tr, window=windows, min_period=process_params["minimum_period_in_s"], max_period=process_params["maximum_period_in_s"], adj_src_type=ad_src_type, window_set=window_set_name, taper_ratio=0.15, taper_type="cosine", plot=plot, envelope_scaling=env_scaling, ) except: # Either pass or fail for the whole component. continue if not asrc: continue # Sum up both misfit, and adjoint source. 
misfit = asrc.misfit adj_source = asrc.adjoint_source.data adjoint_sources[data_tr.id] = { "misfit": misfit, "adj_source": adj_source, } return adjoint_sources ds = pyasdf.ASDFDataSet(processed_filename, mode="r", mpi=False) ds_synth = pyasdf.ASDFDataSet(synthetic_filename, mode="r", mpi=False) # Launch the processing. This will be executed in parallel across # ranks. results = process_two_files_without_parallel_output( ds, ds_synth, process ) # Write files on all ranks. filename = self.get_filename( event=event["event_name"], iteration=iteration ) long_iter_name = self.comm.iterations.get_long_iteration_name( iteration ) misfit_toml = self.comm.project.paths["iterations"] toml_filename = misfit_toml / long_iter_name / "misfits.toml" ad_src_counter = 0 size = MPI.COMM_WORLD.size if MPI.COMM_WORLD.rank == 0: if os.path.exists(toml_filename): iteration_misfits = toml.load(toml_filename) if event["event_name"] in iteration_misfits.keys(): iteration_misfits[event["event_name"]][ "event_misfit" ] = 0.0 with open(toml_filename, "w") as fh: toml.dump(iteration_misfits, fh) MPI.COMM_WORLD.Barrier() for thread in range(size): rank = MPI.COMM_WORLD.rank if rank == thread: print( f"Writing adjoint sources for rank: {rank+1} " f"out of {size}", flush=True, ) with pyasdf.ASDFDataSet( filename=filename, mpi=False, mode="a" ) as bs: if toml_filename.exists(): iteration_misfits = toml.load(toml_filename) if event["event_name"] in iteration_misfits.keys(): total_misfit = iteration_misfits[ event["event_name"] ]["event_misfit"] else: iteration_misfits[event["event_name"]] = {} iteration_misfits[event["event_name"]][ "stations" ] = {} total_misfit = 0.0 else: iteration_misfits = {} iteration_misfits[event["event_name"]] = {} iteration_misfits[event["event_name"]]["stations"] = {} total_misfit = 0.0 for value in results.values(): if not value: continue station_misfit = 0.0 for c_id, adj_source in value.items(): net, sta, loc, cha = c_id.split(".") bs.add_auxiliary_data( 
data=adj_source["adj_source"], data_type="AdjointSources", path="%s_%s/Channel_%s_%s" % (net, sta, loc, cha), parameters={"misfit": adj_source["misfit"]}, ) station_misfit += adj_source["misfit"] station_name = f"{net}.{sta}" iteration_misfits[event["event_name"]]["stations"][ station_name ] = float(station_misfit) ad_src_counter += 1 total_misfit += station_misfit iteration_misfits[event["event_name"]][ "event_misfit" ] = float(total_misfit) with open(toml_filename, "w") as fh: toml.dump(iteration_misfits, fh) MPI.COMM_WORLD.barrier() if MPI.COMM_WORLD.rank == 0: with pyasdf.ASDFDataSet( filename=filename, mpi=False, mode="a" ) as ds: length = len(ds.auxiliary_data.AdjointSources.list()) print(f"{length} Adjoint sources are in your file.")
[docs] def finalize_adjoint_sources( self, iteration_name: str, event_name: str, weight_set_name: str = None ): """ Work with adjoint source in a way that it is written down properly into an hdf5 file and prepared for being used as a source time function. The misfit values and adjoint sources are multiplied by the weight of the event and the station. :param iteration_name: Name of iteration :type iteration_name: str :param event_name: Name of event :type event_name: str :param weight_set_name: Name of station weights, defaults to None :type weight_set_name: str, optional """ import pyasdf import h5py print("Finalizing adjoint sources...") # This will do stuff for each event and a single iteration # Step one, read adj_src file that should have been created already iteration = self.comm.iterations.get_long_iteration_name( iteration_name ) adj_src_file = self.get_filename(event_name, iteration) ds = pyasdf.ASDFDataSet(adj_src_file, mpi=False) adj_srcs = ds.auxiliary_data["AdjointSources"] input_files_dir = self.comm.project.paths["adjoint_sources"] receivers = self.comm.query.get_all_stations_for_event(event_name) output_dir = os.path.join(input_files_dir, iteration, event_name) if not os.path.exists(output_dir): os.mkdir(output_dir) adjoint_source_file_name = os.path.join(output_dir, "stf.h5") f = h5py.File(adjoint_source_file_name, "w") event_weight = 1.0 if weight_set_name is not None: ws = self.comm.weights.get(weight_set_name) event_weight = ws.events[event_name]["event_weight"] station_weights = ws.events[event_name]["stations"] computed_misfits = toml.load(self.get_misfit_file(iteration)) for adj_src in adj_srcs: station_name = adj_src.auxiliary_data_type.split("/")[1] channels = adj_src.list() e_comp = np.zeros_like(adj_src[channels[0]].data[()]) n_comp = np.zeros_like(adj_src[channels[0]].data[()]) z_comp = np.zeros_like(adj_src[channels[0]].data[()]) for channel in channels: # check channel and set component if channel[-1] == "E": e_comp = 
adj_src[channel].data[()] elif channel[-1] == "N": n_comp = adj_src[channel].data[()] elif channel[-1] == "Z": z_comp = adj_src[channel].data[()] zne = np.array((z_comp, n_comp, e_comp)) for receiver in receivers.keys(): station = receiver.replace(".", "_") # station = receiver["network"] + "_" + receiver["station"] if station == station_name: # transform_mat = np.array(receiver["transform_matrix"]) # xyz = np.dot(transform_mat.T, zne).T # net_dot_sta = \ # receiver["network"] + "." + receiver["station"] if weight_set_name is not None: weight = ( station_weights[receiver]["station_weight"] * event_weight ) zne *= weight computed_misfits[event_name]["stations"][ receiver ] *= weight source = f.create_dataset(station, data=zne.T) source.attrs["dt"] = self.comm.project.simulation_settings[ "time_step_in_s" ] source.attrs["sampling_rate_in_hertz"] = ( 1 / source.attrs["dt"] ) # source.attrs['location'] = np.array( # [receivers[receiver]["s"]]) source.attrs["spatial-type"] = np.string_("vector") # Start time in nanoseconds source.attrs[ "start_time_in_seconds" ] = self.comm.project.simulation_settings[ "start_time_in_s" ] # toml_string += f"[[source]]\n" \ # f"name = \"{station}\"\n" \ # f"dataset_name = \"/{station}\"\n\n" if weight_set_name is not None: computed_misfits[event_name]["event_misfit"] = np.sum( np.array( list(computed_misfits[event_name]["stations"].values()) ) ) with open(self.get_misfit_file(iteration), "w") as fh: toml.dump(computed_misfits, fh) f.close()
@staticmethod def _validate_return_value(adsrc): if not isinstance(adsrc, dict): return False elif sorted(adsrc.keys()) != [ "adjoint_source", "details", "misfit_value", ]: return False elif not isinstance(adsrc["adjoint_source"], np.ndarray): return False elif not isinstance(adsrc["misfit_value"], float): return False elif not isinstance(adsrc["details"], dict): return False return True