pybird.integrated_model_jax module
- class pybird.integrated_model_jax.CustomActivation_jax(*args: Any, **kwargs: Any)
Bases: Module
- class pybird.integrated_model_jax.IntegratedModel(keras_model, input_scaler, output_scaler, temp_file=None, offset=None, log_preprocess=False, zero_columns=None, rescaling_factor=None, pca=None, pca_scaler=None, verbose=False)
Bases: object
A class that wraps a Flax model in JAX together with its preprocessing and postprocessing steps.
The IntegratedModel class combines a Flax model with input/output scaling, PCA transformations, and other preprocessing steps. It handles model predictions with proper scaling and transformation, and can restore models from saved files.
- model
The Flax model for predictions.
- Type: flax.linen.Module
- scaler_mean_in
Mean values for input scaling.
- Type: ndarray
- scaler_scale_in
Scale values for input scaling.
- Type: ndarray
- scaler_mean_out
Mean values for output scaling.
- Type: ndarray
- scaler_scale_out
Scale values for output scaling.
- Type: ndarray
- pca_components
PCA components if used.
- Type: ndarray
- pca_mean
PCA mean values if used.
- Type: ndarray
- pca_scaler_mean
Mean values for PCA scaling.
- Type: ndarray
- pca_scaler_scale
Scale values for PCA scaling.
- Type: ndarray
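The attributes above suggest a standardize, predict, un-standardize pipeline. The following is a minimal sketch of how such arrays are commonly combined, not the class's actual implementation. The function names (standardize, unstandardize, pca_inverse, predict_sketch), the parameter dictionary, and the order in which the PCA and output scalers are undone are all assumptions made for illustration. NumPy is used to keep the sketch dependency-light, and the network is stood in for by any callable.

```python
import numpy as np


def standardize(x, mean, scale):
    """Map raw values to standardized values: (x - mean) / scale."""
    return (x - mean) / scale


def unstandardize(z, mean, scale):
    """Invert standardize: z * scale + mean."""
    return z * scale + mean


def pca_inverse(coeffs, components, pca_mean):
    """Map PCA coefficients (k,) back to feature space (n,).

    Assumes components has shape (k, n), one component per row.
    """
    return coeffs @ components + pca_mean


def predict_sketch(x, net, p):
    """Hypothetical prediction pipeline; the step order is an assumption.

    x   : raw input vector
    net : any callable standing in for the trained network
    p   : dict keyed by the attribute names documented above
    """
    z = standardize(x, p["scaler_mean_in"], p["scaler_scale_in"])
    y = net(z)
    if p.get("pca_components") is not None:
        # Assumed: the network predicts standardized PCA coefficients.
        y = unstandardize(y, p["pca_scaler_mean"], p["pca_scaler_scale"])
        y = pca_inverse(y, p["pca_components"], p["pca_mean"])
    return unstandardize(y, p["scaler_mean_out"], p["scaler_scale_out"])


# Toy usage with an identity "network" and no PCA step.
params = {
    "scaler_mean_in": np.array([1.0, 1.0]),
    "scaler_scale_in": np.array([2.0, 2.0]),
    "scaler_mean_out": np.array([10.0, 20.0]),
    "scaler_scale_out": np.array([2.0, 4.0]),
    "pca_components": None,
}
result = predict_sketch(np.array([1.0, 2.0]), lambda z: z, params)
```

Keeping the scaling parameters as plain arrays, rather than as fitted scaler objects, means the whole pipeline reduces to elementwise arithmetic and one matrix product, which is the kind of computation that JAX can trace and differentiate.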