# Source code for prediction

"""
This module handles the prediction phase using a pre-trained Keras neural network model.
It loads the catalog, pre-processes the data (normalization and feature engineering),
and generates predictions (AGN vs Pulsar) based on the provided model.
"""

import argparse
import sys
from pathlib import Path

# pylint: disable=import-error, wrong-import-position
from loguru import logger
import numpy as np
import pandas as pd
import keras
from sklearn.preprocessing import StandardScaler

# Locate the repository root: the closest ancestor of this file that contains a
# ".git" directory. The entry is checked by name instead of searching for the
# substring ".git" inside every child path, which would also match ".github" or
# a checkout living under a folder such as "project.git/". Stopping at the
# first hit means directories above the repository are never inspected.
# The path is resolved first so a relative __file__ still yields all ancestors.
git_dir = next(
    (
        parent
        for parent in Path(__file__).resolve().parents
        if (parent / ".git").is_dir()
    ),
    None,
)
if git_dir is None:
    raise FileNotFoundError(
        "Git Directory Not Found. Please ensure that you cloned the repository in the right way."
    )
# Make the project-local helper modules under <repo>/imports importable.
import_dir = git_dir / "imports/"
sys.path.append(import_dir.as_posix())
import custom_variables as custom_paths

# pylint: enable=import-error, wrong-import-position

def model_prediction(
    catalog_path=custom_paths.csv_path,
    model_path=custom_paths.model_path,
    threshold=0.63,
):
    """
    Performs predictions on the provided astronomical catalog using a trained Keras model.

    The function loads the catalog and the model, normalizes the flux-band and
    flux-history columns with a StandardScaler fitted on the catalog itself, and
    computes the classification (AGN or Pulsar) based on the specified threshold.

    :param catalog_path: Path to the input CSV catalog containing source data.
        Defaults to `custom_paths.csv_path`.
    :type catalog_path: str or pathlib.Path, optional
    :param model_path: Path to the saved Keras model (.keras).
        Defaults to `custom_paths.model_path`.
    :type model_path: str or pathlib.Path, optional
    :param threshold: The decision threshold for the binary classification.
        Values >= threshold are classified as Pulsar, otherwise AGN.
        Numeric strings (e.g. coming from the command line) are accepted.
        Defaults to 0.63.
    :type threshold: float, optional
    :raises ValueError: If `threshold` cannot be converted to a float.
    :return: An array of string labels ('AGN' or 'Pulsar') corresponding to the predictions.
    :rtype: numpy.ndarray
    """
    # Command-line callers may hand the threshold over as text; comparing a
    # float array against a str fails, so normalise it once up front.
    threshold = float(threshold)

    logger.info("Importing Model..")
    model = keras.models.load_model(model_path)
    logger.info("Importing Catalog")
    df = pd.read_csv(catalog_path)

    col_input1 = [
        "GLAT",
        "Variability_Index",
        "PowerLaw",
        "LogParabola",
        "PLSuperExpCutoff",
    ]
    # Flux and significance columns are interleaved per band / per time bin:
    # Flux_Band_0, Sqrt_TS_Band_0, Flux_Band_1, ... The order is part of the
    # model input layout and must not change.
    col_flux_band = [
        name
        for i in range(8)
        for name in (f"Flux_Band_{i}", f"Sqrt_TS_Band_{i}")
    ]
    col_flux_hist = [
        name
        for i in range(14)
        for name in (f"Flux_History_{i}", f"Sqrt_TS_History_{i}")
    ]
    norm_cols = col_flux_band + col_flux_hist

    # NOTE(review): the scaler statistics come from the catalog being
    # classified, not from a stored training-time scaler - confirm this matches
    # the preprocessing the model was trained with.
    df[norm_cols] = StandardScaler().fit_transform(df[norm_cols])

    input_additional = df[col_input1].to_numpy()
    input_flux_band = df[col_flux_band].to_numpy()
    input_flux_hist = df[col_flux_hist].to_numpy()

    logger.info("Starting Predictions...")
    predictions = model.predict([input_flux_band, input_flux_hist, input_additional])

    # Scores at or above the threshold are Pulsars, everything else is an AGN.
    th_pred = np.where(predictions >= threshold, "Pulsar", "AGN")
    # Collapse the model output to one label per catalog row.
    return th_pred.reshape(len(th_pred))
if __name__ == "__main__":
    # Command-line entry point: classify a catalog and store the labels as .npy.
    parser = argparse.ArgumentParser(
        description="This script can be executed to generate predictions with given model."
    )
    parser.add_argument(
        "--csv_path",
        "-i",
        default=str(custom_paths.csv_path),
        help="Path to the input CSV catalog containing the source data to be classified.",
    )
    parser.add_argument(
        "--threshold",
        "-th",
        # Without an explicit type argparse passes user-supplied values as str,
        # which cannot be compared against the model's float outputs.
        type=float,
        default=0.63,
        help="The probability threshold used to distinguish between AGN and Pulsar classes.\nSources with a probability higher than this value are classified as Pulsars.",
    )
    parser.add_argument(
        "--model_path",
        "-m",
        default=str(custom_paths.model_path),
        help="Path to the pre-trained Keras model file (.keras) to be used for inference.",
    )
    parser.add_argument(
        "--prediction_path",
        "-p",
        default=str(custom_paths.prediction_path),
        help="Output path where the resulting predictions array (.npy) will be saved.",
    )
    args = parser.parse_args()

    preds = model_prediction(
        catalog_path=args.csv_path,
        model_path=args.model_path,
        threshold=args.threshold,
    )
    logger.info("Saving Predictions..")
    np.save(args.prediction_path, preds)