Source code for finetuner.experiment

from dataclasses import fields
from typing import Any, Dict, List, Optional, Union

from docarray import DocumentArray

from finetuner.callback import EvaluationCallback
from finetuner.client import FinetunerV1Client
from finetuner.constants import (
    BATCH_SIZE,
    CALLBACKS,
    CONFIG,
    CPU,
    CREATED_AT,
    DATA,
    DESCRIPTION,
    EPOCHS,
    EVAL_DATA,
    EXPERIMENT_NAME,
    FREEZE,
    HYPER_PARAMETERS,
    LEARNING_RATE,
    LOSS,
    MINER,
    MODEL,
    MODEL_OPTIONS,
    NAME,
    NUM_WORKERS,
    OPTIMIZER,
    OPTIMIZER_OPTIONS,
    OPTIONS,
    OUTPUT_DIM,
    RUN_NAME,
    SCHEDULER_STEP,
    TRAIN_DATA,
)
from finetuner.hubble import push_data
from finetuner.names import get_random_name
from finetuner.run import Run


[docs]class Experiment: """Class for an experiment. :param client: Client object for sending api requests. :param name: Name of the experiment. :param status: Status of the experiment. :param created_at: Creation time of the experiment. :param description: Optional description of the experiment. """ def __init__( self, client: FinetunerV1Client, name: str, status: str, created_at: str, description: Optional[str] = '', ): self._client = client self._name = name self._status = status self._created_at = created_at self._description = description @property def name(self) -> str: return self._name @property def status(self) -> str: return self._status
[docs] def get_run(self, name: str) -> Run: """Get a run by its name. :param name: Name of the run. :return: A `Run` object. """ run_info = self._client.get_run(experiment_name=self._name, run_name=name) run = Run( name=run_info[NAME], config=run_info[CONFIG], created_at=run_info[CREATED_AT], description=run_info[DESCRIPTION], experiment_name=self._name, client=self._client, ) return run
[docs] def list_runs(self) -> List[Run]: """List every run inside the experiment. :return: List of `Run` objects. """ run_infos = self._client.list_runs(experiment_name=self._name) return [ Run( name=run_info[NAME], config=run_info[CONFIG], created_at=run_info[CREATED_AT], description=run_info[DESCRIPTION], experiment_name=self._name, client=self._client, ) for run_info in run_infos ]
[docs] def delete_run(self, name: str): """Delete a run by its name. :param name: Name of the run. """ self._client.delete_run(experiment_name=self._name, run_name=name)
[docs] def delete_runs(self): """Delete every run inside the experiment.""" self._client.delete_runs(experiment_name=self._name)
[docs] def create_run( self, model: str, train_data: Union[DocumentArray, str], run_name: Optional[str] = None, **kwargs, ) -> Run: """Create a run inside the experiment. :param model: Name of the model to be fine-tuned. :param train_data: Either a `DocumentArray` for training data or a name of the `DocumentArray` that is pushed on Hubble. :param run_name: Optional name of the run. :param kwargs: Optional keyword arguments for the run config. :return: A `Run` object. """ if not run_name: run_name = get_random_name() eval_callback = None callbacks = kwargs[CALLBACKS] if kwargs.get(CALLBACKS) else [] for callback in callbacks: if isinstance(callback, EvaluationCallback): eval_callback = callback train_data, eval_data, query_data, index_data = push_data( experiment_name=self._name, run_name=run_name, train_data=train_data, eval_data=kwargs.get(EVAL_DATA), query_data=eval_callback.query_data if eval_callback else None, index_data=eval_callback.index_data if eval_callback else None, ) if query_data or index_data: eval_callback.query_data = query_data eval_callback.index_data = index_data kwargs[EVAL_DATA] = eval_data config = self._create_config_for_run( model=model, train_data=train_data, experiment_name=self._name, run_name=run_name, **kwargs, ) cpu = kwargs.get(CPU, True) num_workers = kwargs.get(NUM_WORKERS, 4) run_info = self._client.create_run( run_name=run_name, experiment_name=self._name, run_config=config, device='cpu' if cpu else 'gpu', cpus=num_workers, gpus=1, ) run = Run( client=self._client, name=run_info[NAME], experiment_name=self._name, config=run_info[CONFIG], created_at=run_info[CREATED_AT], description=run_info[DESCRIPTION], ) return run
@staticmethod def _create_config_for_run( model: str, train_data: str, experiment_name: str, run_name: str, **kwargs, ) -> Dict[str, Any]: """Create config for a run. :param model: Name of the model to be fine-tuned. :param train_data: Either a `DocumentArray` for training data or a name of the `DocumentArray` that is pushed on Hubble. :param experiment_name: Name of the experiment. :param run_name: Name of the run. :param kwargs: Optional keyword arguments for the run config. :return: Run parameters wrapped up as a config dict. """ callbacks = kwargs[CALLBACKS] if kwargs.get(CALLBACKS) else [] callbacks = [ { NAME: callback.__class__.__name__, OPTIONS: { field.name: getattr(callback, field.name) for field in fields(callback) }, } for callback in callbacks ] return { MODEL: { NAME: model, FREEZE: kwargs.get(FREEZE), OUTPUT_DIM: kwargs.get(OUTPUT_DIM), OPTIONS: kwargs.get(MODEL_OPTIONS) or {}, }, DATA: { TRAIN_DATA: train_data, EVAL_DATA: kwargs.get(EVAL_DATA), NUM_WORKERS: kwargs.get(NUM_WORKERS), }, CALLBACKS: callbacks, HYPER_PARAMETERS: { LOSS: kwargs.get(LOSS), OPTIMIZER: kwargs.get(OPTIMIZER), OPTIMIZER_OPTIONS: {}, MINER: kwargs.get(MINER), BATCH_SIZE: kwargs.get(BATCH_SIZE), LEARNING_RATE: kwargs.get(LEARNING_RATE), EPOCHS: kwargs.get(EPOCHS), SCHEDULER_STEP: kwargs.get(SCHEDULER_STEP), }, EXPERIMENT_NAME: experiment_name, RUN_NAME: run_name, }