diff --git a/pyotb/apps.py b/pyotb/apps.py index 9bf67599319b7b1fb1fa215f85f86fe3bd92626c..5cc758a68986f50e2cbd951858e8dc508473b830 100644 --- a/pyotb/apps.py +++ b/pyotb/apps.py @@ -83,40 +83,57 @@ class {name}(App): def __init__(self, *args, **kwargs): super().__init__('{name}', *args, **kwargs) """ + + +class OTBTFApp(App): + """ + Helper for OTBTF + """ + @staticmethod + def set_nb_sources(*args, n_sources=None): + """Set the number of sources of TensorflowModelServe. Can be either user-defined or deduced from the args + + :param args: arguments (dict). NB: we don't need kwargs because it cannot contain source#.il + :param n_sources: number of sources. Default is None (resolves the number of sources based on the + content of the dict passed in args, where some 'source' str is found) + """ + if n_sources: + os.environ['OTB_TF_NSOURCES'] = str(int(n_sources)) + else: + # Retrieving the number of `source#.il` parameters + params_dic = {k: v for arg in args if isinstance(arg, dict) for k, v in arg.items()} + n_sources = len([k for k in params_dic if 'source' in k and k.endswith('.il')]) + if n_sources >= 1: + os.environ['OTB_TF_NSOURCES'] = str(n_sources) + + def __init__(self, app_name, *args, n_sources=None, **kwargs): + """ + :param args: args + :param n_sources: number of sources. Default is None (resolves the number of sources based on the + content of the dict passed in args, where some 'source' str is found) + :param kwargs: kwargs + """ + self.set_nb_sources(*args, n_sources=n_sources) + super().__init__(app_name, *args, **kwargs) + + for _app in AVAILABLE_APPLICATIONS: # Default behavior for any OTB application exec(_code_template.format(name=_app)) # pylint: disable=exec-used - # Customize the behavior for TensorflowModelServe application. The user doesn't need to set the env variable + # Customize the behavior for some OTBTF applications. The user doesn't need to set the env variable # `OTB_TF_NSOURCES`, it is handled in pyotb if _app == 'TensorflowModelServe': - class TensorflowModelServe(App): - """ - Helper for OTBTF - """ - @staticmethod - def set_nb_sources(*args, n_sources=None): - """ - Set the number of sources of TensorflowModelServe. Can be either user-defined or deduced from the args - :param args: arguments - :param n_sources: number of sources. Default is None (resolves the number of sources based on the - content of the dict passed in args, where some 'source' str is found) - """ - if n_sources: - os.environ['OTB_TF_NSOURCES'] = str(int(n_sources)) - else: - # Retrieving the number of `source#.il` parameters - params_dic = {k: v for arg in args if isinstance(arg, dict) for k, v in arg.items()} - n_sources = len([k for k in params_dic if k.startswith('source') and k.endswith('.il')]) - if n_sources >= 1: - os.environ['OTB_TF_NSOURCES'] = str(n_sources) + class TensorflowModelServe(OTBTFApp): + def __init__(self, *args, n_sources=None, **kwargs): + super().__init__('TensorflowModelServe', *args, n_sources=n_sources, **kwargs) + + elif _app == 'PatchesExtraction': + class PatchesExtraction(OTBTFApp): + def __init__(self, *args, n_sources=None, **kwargs): + super().__init__('PatchesExtraction', *args, n_sources=n_sources, **kwargs) + elif _app == 'TensorflowModelTrain': + class TensorflowModelTrain(OTBTFApp): def __init__(self, *args, n_sources=None, **kwargs): - """ - :param args: args - :param n_sources: number of sources. Default is None (resolves the number of sources based on the - content of the dict passed in args, where some 'source' str is found) - :param kwargs: kwargs - """ - self.set_nb_sources(*args, n_sources=n_sources) - super().__init__('TensorflowModelServe', *args, **kwargs) + super().__init__('TensorflowModelTrain', *args, n_sources=n_sources, **kwargs)