Skip to content

Commit

Permalink
add lora_soup evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
shuishen112 committed Dec 30, 2024
1 parent 6321495 commit d5a8728
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions gsm_evaluator_with_lora_soup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from mttl.evaluators.rouge_evaluator import RougeEvaluator
from mttl.models.containers.selectors.base import UniformSelectorConfig
from mttl.arguments import EvaluationConfig, ExpertConfig
from mttl.models.lightning.expert_module import ExpertModule
from mttl.models.lightning.expert_module import ExpertModule, MultiExpertModule
import torch
from mttl.logging import setup_logging

Expand All @@ -28,8 +28,11 @@
datamodule = get_datamodule(args, for_generation=True)
evaluator = GsmEvaluator(datamodule)

#
module = ExpertModule(**vars(args)).to(device)
if args.library_id is None:
module = ExpertModule(**vars(args)).to(device)
else:
module = MultiExpertModule(**vars(args)).to("cuda")
module.add_experts_from_library(args.library_id)

if args.checkpoint is not None:
checkpoint = torch.load(args.checkpoint, weights_only=False)["state_dict"]
Expand Down

0 comments on commit d5a8728

Please sign in to comment.