Skip to content

Commit

Permalink
formatting
Browse files (browse the repository at this point in the history)
  • Loading branch information
Safoora Yousefi committed Jan 17, 2025
1 parent f52164c commit 15d156f
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions eureka_ml_insights/core/inference.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -11,12 +11,21 @@

from .pipeline import Component
from .reserved_names import INFERENCE_RESERVED_NAMES

MINUTE = 60


class Inference(Component):
def __init__(self, model_config, data_config, output_dir, resume_from=None, new_columns=None, requests_per_minute=None, max_concurrent=1):

def __init__(
self,
model_config,
data_config,
output_dir,
resume_from=None,
new_columns=None,
requests_per_minute=None,
max_concurrent=1,
):
"""
Initialize the Inference component.
args:
Expand Down Expand Up @@ -62,13 +71,13 @@ def fetch_previous_inference_results(self):
# fetch previous results from the provided resume_from file
logging.info(f"Resuming inference from {self.resume_from}")
pre_inf_results_df = DataReader(self.resume_from, format=".jsonl").load_dataset()

# add new columns listed by the user to the previous inference results
if self.new_columns:
for col in self.new_columns:
if col not in pre_inf_results_df.columns:
pre_inf_results_df[col] = None

# validate the resume_from contents
with self.data_loader as loader:
_, sample_model_input = self.data_loader.get_sample_model_input()
Expand All @@ -81,14 +90,16 @@ def fetch_previous_inference_results(self):
# perform a sample inference call to get the model output keys and validate the resume_from contents
sample_response_dict = self.model.generate(*sample_model_input)
if not sample_response_dict["is_valid"]:
raise ValueError("Sample inference call for resume_from returned invalid results, please check the model configuration.")
raise ValueError(
"Sample inference call for resume_from returned invalid results, please check the model configuration."
)
# check if the inference response dictionary contains the same keys as the resume_from file
eventual_keys = set(sample_response_dict.keys()) | set(sample_data_keys)

# in case of resuming from a file that was generated by an older version of the model,
# we let the discrepancy in the reserved keys slide and later set the missing keys to None
match_keys = set(pre_inf_results_df.columns) | set(INFERENCE_RESERVED_NAMES)
match_keys = set(pre_inf_results_df.columns) | set(INFERENCE_RESERVED_NAMES)

if set(eventual_keys) != match_keys:
diff = set(eventual_keys) ^ set(match_keys)
raise ValueError(
Expand Down

0 comments on commit 15d156f

Please sign in to comment.