pybird.integrated_model_jax module

class pybird.integrated_model_jax.Network(*args: Any, **kwargs: Any)

Bases: Module

weights: list
hyper_params: list
class pybird.integrated_model_jax.CustomActivation_jax(*args: Any, **kwargs: Any)

Bases: Module

a: float
b: float
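The two classes above list only their fields. The sketch below illustrates one way such fields could drive a forward pass. It assumes three things that this page does not confirm: `weights` holds one `(W, bias)` pair per dense layer, `hyper_params` holds one `(a, b)` pair per hidden layer, and `CustomActivation_jax` computes a gated activation f(x) = (b + sigmoid(a*x) * (1 - b)) * x. The helper names `custom_activation` and `mlp_forward` are hypothetical. The sketch also uses plain JAX functions rather than Flax modules.

```python
import jax
import jax.numpy as jnp

# HYPOTHETICAL functional form; the real CustomActivation_jax may differ.
def custom_activation(x, a, b):
    # b sets the linear part, sigmoid(a * x) gates the remaining (1 - b) part.
    return (b + jax.nn.sigmoid(a * x) * (1.0 - b)) * x

def mlp_forward(x, weights, hyper_params):
    """Dense forward pass, assuming weights = [(W0, c0), (W1, c1), ...].

    Hidden layers use custom_activation with one (a, b) pair each;
    the final layer is linear.
    """
    h = x
    for (W, c), (a, b) in zip(weights[:-1], hyper_params):
        h = custom_activation(h @ W + c, a, b)
    W_out, c_out = weights[-1]
    return h @ W_out + c_out

# Usage: a 3 -> 8 -> 2 network with random weights.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
weights = [(jax.random.normal(k1, (3, 8)), jnp.zeros(8)),
           (jax.random.normal(k2, (8, 2)), jnp.zeros(2))]
hyper_params = [(1.0, 0.1)]
y = mlp_forward(jnp.ones((4, 3)), weights, hyper_params)
print(y.shape)  # (4, 2)
```

Under this assumed form, b = 1 reduces the activation to the identity. With b = 0, it becomes the smooth gate x * sigmoid(a*x).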
pybird.integrated_model_jax.insert_zero_columns(prediction, zero_columns_indices)
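`insert_zero_columns` has no docstring. Judging from its name and the `zero_columns` attribute below, it likely re-expands a prediction to full width by putting zeros back into columns that were dropped before training. The sketch below is a hypothetical re-implementation (`insert_zero_columns_sketch`). It assumes that the indices refer to positions in the full output and that the remaining columns keep their original order.

```python
import numpy as np
import jax.numpy as jnp

def insert_zero_columns_sketch(prediction, zero_columns_indices):
    """HYPOTHETICAL: rebuild the full-width output with zeros at the given columns."""
    prediction = jnp.atleast_2d(prediction)
    n_full = prediction.shape[1] + len(zero_columns_indices)
    # Boolean mask of the columns that carry predicted (non-zero) values.
    keep = np.ones(n_full, dtype=bool)
    keep[np.asarray(zero_columns_indices, dtype=int)] = False
    full = jnp.zeros((prediction.shape[0], n_full), dtype=prediction.dtype)
    # Scatter the predicted columns into the remaining slots, in order.
    return full.at[:, np.flatnonzero(keep)].set(prediction)

out = insert_zero_columns_sketch(jnp.array([[1.0, 2.0, 3.0]]), [0, 2])
print(out)  # [[0. 1. 0. 2. 3.]]
```

The mask is built with NumPy, so the column layout is static. This keeps the function compatible with `jax.jit` as long as the indices are passed as static Python or NumPy values.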
class pybird.integrated_model_jax.IntegratedModel(keras_model, input_scaler, output_scaler, temp_file=None, offset=None, log_preprocess=False, zero_columns=None, rescaling_factor=None, pca=None, pca_scaler=None, verbose=False)

Bases: object

A class to integrate and manage Flax models in JAX with preprocessing.

The IntegratedModel class combines a Flax model with input/output scaling, PCA transformations, and other preprocessing steps. It handles model predictions with proper scaling and transformation, and can restore models from saved files.

Attributes:

model (flax.linen.Module): The Flax model for predictions.

input_scaler (object): Scaler for input data preprocessing.

output_scaler (object): Scaler for output data postprocessing.

offset (float): Offset value for output adjustment.

zero_columns (list): Indices of columns to be set to zero in output.

rescaling_factor (float): Factor for rescaling output.

temp_file (str): Path to temporary file for model storage.

train_losses (list): History of training losses.

val_losses (list): History of validation losses.

log_preprocess (bool): Whether to apply log preprocessing.

pca (object): PCA transformation object, if used.

pca_scaler (object): Scaler for PCA-transformed data.

verbose (bool): Whether to print verbose information.

scaler_mean_in (ndarray): Mean values for input scaling.

scaler_scale_in (ndarray): Scale values for input scaling.

scaler_mean_out (ndarray): Mean values for output scaling.

scaler_scale_out (ndarray): Scale values for output scaling.

pca_components (ndarray): PCA components, if used.

pca_mean (ndarray): PCA mean values, if used.

pca_scaler_mean (ndarray): Mean values for PCA scaling.

pca_scaler_scale (ndarray): Scale values for PCA scaling.
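The `pca_*` attributes suggest a network that predicts standardized PCA coefficients rather than the full output vector. The sketch below shows the forward projection and its inverse. It assumes scikit-learn `PCA` conventions without whitening, where the inverse map is coefficients @ components_ + mean_, and StandardScaler-style scaling of the coefficients. The function names `pca_forward` and `pca_inverse` are hypothetical. Only the argument names mirror the attributes listed above.

```python
import jax.numpy as jnp

def pca_forward(y, pca_components, pca_mean, pca_scaler_mean, pca_scaler_scale):
    # Project onto the principal axes, then standardize the coefficients.
    coeffs = (y - pca_mean) @ pca_components.T
    return (coeffs - pca_scaler_mean) / pca_scaler_scale

def pca_inverse(z, pca_components, pca_mean, pca_scaler_mean, pca_scaler_scale):
    # Undo the coefficient standardization, then map back to the full output space.
    coeffs = z * pca_scaler_scale + pca_scaler_mean
    return coeffs @ pca_components + pca_mean

# Two orthonormal components in a 3-dimensional output space.
params = (jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]),  # pca_components
          jnp.array([0.5, -1.0, 2.0]),                    # pca_mean
          jnp.array([0.1, 0.2]),                          # pca_scaler_mean
          jnp.array([2.0, 4.0]))                          # pca_scaler_scale
y = jnp.array([[1.5, 3.0, 2.0]])  # lies in the subspace spanned by the components
z = pca_forward(y, *params)
y_back = pca_inverse(z, *params)
```

The round trip is exact here only because `y` lies in the span of the retained components plus the mean. In general, the inverse returns the projection of the output onto that subspace.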

Methods:

predict(data): Make predictions with proper preprocessing and postprocessing.

restore(h5_filename): Restore model parameters and scalers from a saved file.

predict(data)
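`predict` has no docstring here. The sketch below, `predict_sketch`, is a hypothetical illustration of the general pattern implied by the attributes: standardize the inputs with `scaler_mean_in`/`scaler_scale_in`, run the network, then undo the output preprocessing in reverse order. The order of the post-processing steps is an assumption, as is the meaning of each step. The sketch assumes that `log_preprocess` means natural-log targets, that `offset` was added before taking the log, and that `rescaling_factor` multiplies the output. The network is passed in as a plain callable `network_fn`.

```python
import jax.numpy as jnp

def predict_sketch(data, network_fn, scaler_mean_in, scaler_scale_in,
                   scaler_mean_out, scaler_scale_out,
                   log_preprocess=False, offset=None, rescaling_factor=None):
    """HYPOTHETICAL predict pipeline; step order and semantics are assumed."""
    x = (jnp.atleast_2d(data) - scaler_mean_in) / scaler_scale_in  # standardize inputs
    y = network_fn(x) * scaler_scale_out + scaler_mean_out         # un-standardize outputs
    if log_preprocess:
        y = jnp.exp(y)             # assumes targets were trained in natural-log space
    if offset is not None:
        y = y - offset             # assumes the offset was added before the log
    if rescaling_factor is not None:
        y = y * rescaling_factor   # assumes outputs were divided by this factor
    return y

identity = lambda x: x  # stand-in for the trained network
data = jnp.array([[2.0, 4.0]])
mean_in, scale_in = jnp.array([1.0, 2.0]), jnp.array([1.0, 2.0])  # x becomes [[1, 1]]
plain = predict_sketch(data, identity, mean_in, scale_in,
                       jnp.array([10.0, 20.0]), jnp.array([2.0, 3.0]))
full = predict_sketch(data, identity, mean_in, scale_in,
                      jnp.zeros(2), jnp.ones(2),
                      log_preprocess=True, offset=1.0, rescaling_factor=2.0)
```

Inverting the preprocessing steps in the reverse of the order in which they were applied is the key design point. Any mismatch silently produces wrongly scaled outputs rather than an error.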
restore(h5_filename)

Load pre-saved IntegratedModel attributes from an .h5 file.

Parameters:

h5_filename (str): Filename of the .h5 file where the model was saved.