import contextlib, io, sys
from types import ModuleType, SimpleNamespace
class _FakeOptunaTrial:
    """Plain attribute holder standing in for an Optuna trial object.

    Args:
        number: Trial index, stored as ``number``.
        values: Stored as ``values`` (the fake study sets it for
            multi-objective trials).
        value: Stored as ``value`` (the fake study sets it for
            single-objective trials).
        params: Mapping stored as ``params``; any falsy value (``None``,
            ``{}``) is replaced by a new empty dict.
        state: Stored as ``state``; defaults to ``'COMPLETE'``.
    """

    def __init__(self, number, values=None, value=None, params=None, state='COMPLETE'):
        self.number = number
        self.values = values
        self.value = value
        # A truthy mapping is kept by reference; anything falsy becomes a fresh dict.
        self.params = params if params else {}
        self.state = state
class _FakeOptunaStudy:
    """Fake study object that records how its results are queried.

    The constructor takes the keyword dict that would be passed to
    ``create_study`` and seeds the study with one fake trial. The flags
    ``best_trial_accessed`` and ``best_trials_accessed`` flip to ``True``
    the first time the matching property is read, so a test can check
    which accessor the code under test used.
    """

    def __init__(self, kwargs):
        self.study_name = kwargs.get('study_name') or 'fake-study'

        multi = kwargs.get('directions')
        single = kwargs.get('direction')
        # An explicit ``directions`` value wins; otherwise wrap the single
        # ``direction`` (possibly None) in a one-element list.
        if multi is None:
            self.directions = [single]
        else:
            self.directions = list(multi)

        # More than one direction means a multi-objective study, whose seed
        # trial carries ``values`` instead of ``value``.
        if len(self.directions) > 1:
            seed_trial = _FakeOptunaTrial(0, values=[1., 2.], params={'x': 1})
        else:
            seed_trial = _FakeOptunaTrial(0, value=1., params={'x': 1})
        self._best_trials = [seed_trial]
        # ``trials`` is the very same list object as ``_best_trials``.
        self.trials = self._best_trials

        self.best_trial_accessed = False
        self.best_trials_accessed = False

    def enqueue_trial(self, evaluate):
        """Remember the enqueued payload on ``self.enqueued``."""
        self.enqueued = evaluate

    def optimize(self, *args, **kwargs):
        """Record the call as ``self.optimize_args = (args, kwargs)``; runs nothing."""
        self.optimize_args = (args, kwargs)

    @property
    def best_trial(self):
        """Return the first best trial, marking ``best_trial_accessed``.

        Raises:
            RuntimeError: If the study has more than one direction.
        """
        self.best_trial_accessed = True
        if len(self.directions) <= 1:
            return self._best_trials[0]
        raise RuntimeError('single best trial is unavailable')

    @property
    def best_trials(self):
        """Return the list of best trials, marking ``best_trials_accessed``."""
        self.best_trials_accessed = True
        return self._best_trials
class _FakeOptuna(ModuleType):
    """Module object named ``'optuna'`` exposing a few stubbed attributes.

    Attributes set by the constructor:
        created_studies: List of the keyword dicts passed to
            ``create_study``, in call order.
        samplers: Namespace whose ``TPESampler``, ``GridSampler`` and
            ``RandomSampler`` return a ``(tag, arguments)`` tuple.
        trial: Namespace holding ``TrialState`` with string constants
            ``PRUNED`` and ``COMPLETE``.
        visualization: Namespace of plot functions that accept ``study``
            and return ``None``.
    """

    def __init__(self):
        super().__init__('optuna')
        self.created_studies = []

        def tpe_sampler(**kwargs):
            return ('tpe', kwargs)

        def grid_sampler(search_space):
            return ('grid', search_space)

        def random_sampler(**kwargs):
            return ('random', kwargs)

        def no_plot(study):
            return None

        self.samplers = SimpleNamespace(
            TPESampler=tpe_sampler,
            GridSampler=grid_sampler,
            RandomSampler=random_sampler,
        )

        trial_states = SimpleNamespace(PRUNED='PRUNED', COMPLETE='COMPLETE')
        self.trial = SimpleNamespace(TrialState=trial_states)

        plot_names = (
            'plot_optimization_history',
            'plot_param_importances',
            'plot_slice',
            'plot_parallel_coordinate',
        )
        self.visualization = SimpleNamespace(**{name: no_plot for name in plot_names})

    def create_study(self, **kwargs):
        """Log ``kwargs`` in ``created_studies`` and return a fake study built from them."""
        # The log entry is written before the study is constructed.
        self.created_studies.append(kwargs)
        return _FakeOptunaStudy(kwargs)
@contextlib.contextmanager
def _use_fake_optuna(fake):
    """Temporarily register ``fake`` as ``sys.modules['optuna']``.

    On exit the previous ``sys.modules`` state is restored exactly: the
    earlier entry is put back if there was one, otherwise the key is
    removed. The restore also runs when the ``with`` body raises.

    Args:
        fake: Object to install under the ``'optuna'`` key.

    Yields:
        The same ``fake`` object.
    """
    # A private sentinel separates "no entry" from an entry whose value is
    # None. None is a legitimate sys.modules value (it blocks the import) and
    # must be restored rather than dropped.
    absent = object()
    previous = sys.modules.get('optuna', absent)
    sys.modules['optuna'] = fake
    try:
        yield fake
    finally:
        if previous is absent:
            sys.modules.pop('optuna', None)
        else:
            sys.modules['optuna'] = previous
# Scenario: a new study with two objectives, run against the fake optuna
# module with stdout silenced.
fake = _FakeOptuna()
with _use_fake_optuna(fake):
    with contextlib.redirect_stdout(io.StringIO()):
        study = run_optuna_study(
            lambda trial: (1., 2.),
            study_type='random',
            direction=['minimize', 'minimize'],
            n_trials=1,
            save_study=False,
            show_plots=False,
            show_progress_bar=False,
        )

# Keyword arguments of the first (and only expected) create_study call.
created_study = fake.created_studies[0]

# Multi-objective Optuna studies must be created with directions= and reported with best_trials.
assert created_study.get('directions') == ['minimize', 'minimize']
assert 'direction' not in created_study
assert study.best_trials_accessed
assert not study.best_trial_accessed
# Scenario: resuming a saved study. joblib.load is swapped for a stub that
# hands back a prepared fake study, and is always restored afterwards.
old_joblib_load = joblib.load
try:
    # Resume a two-objective study that already holds one trial.
    resumed_study = _FakeOptunaStudy({'directions': ['minimize', 'minimize']})
    joblib.load = lambda resume: resumed_study
    fake = _FakeOptuna()
    with _use_fake_optuna(fake):
        with contextlib.redirect_stdout(io.StringIO()):
            study = run_optuna_study(
                lambda trial: (1., 2.),
                resume='fake-study.pkl',
                n_trials=1,
                save_study=False,
                show_plots=False,
                show_progress_bar=False,
            )
    # The loaded object itself must be returned and summarized via best_trials.
    assert study is resumed_study
    assert study.best_trials_accessed
    assert not study.best_trial_accessed

    # Test resume with empty/no completed trials
    resumed_empty_study = _FakeOptunaStudy({'directions': ['minimize', 'minimize']})
    resumed_empty_study._best_trials = []
    resumed_empty_study.trials = []
    joblib.load = lambda resume: resumed_empty_study
    fake = _FakeOptuna()
    with _use_fake_optuna(fake):
        # Keep the captured stdout so the printed message can be inspected.
        with contextlib.redirect_stdout(io.StringIO()) as trapped:
            study = run_optuna_study(
                lambda trial: (1., 2.),
                resume='fake-study-empty.pkl',
                n_trials=1,
                save_study=False,
                show_plots=False,
                show_progress_bar=False,
            )
    assert study is resumed_empty_study
    assert "No finished trials yet." in trapped.getvalue()
finally:
    # Undo the monkeypatch whether or not the assertions passed.
    joblib.load = old_joblib_load
# Scenario: a new single-objective study, run against the fake optuna module
# with stdout silenced.
fake = _FakeOptuna()
with _use_fake_optuna(fake):
    with contextlib.redirect_stdout(io.StringIO()):
        study = run_optuna_study(
            lambda trial: 1.,
            study_type='random',
            direction='minimize',
            n_trials=1,
            save_study=False,
            show_plots=False,
            show_progress_bar=False,
        )

# Keyword arguments of the first create_study call.
created_study = fake.created_studies[0]

# A single objective must be passed as direction= (never directions=) and
# summarized through best_trial.
assert created_study.get('direction') == 'minimize'
assert 'directions' not in created_study
assert study.best_trial_accessed