Please keep the discussion on the mailing list rather than commenting on the wiki (wiki discussions get unwieldy fast).
Motivation
FLIP-39 rebuilds the Flink ML pipeline on top of the Table API and introduces a new set of Java APIs. Because Python is widely used in machine learning, providing Python ML Pipeline APIs for Flink would both make it easier for Python users to write ML jobs and broaden the adoption of Flink ML. This document discusses the design of the Python ML Pipeline API.
Goals
- Add a Python pipeline API that mirrors the Java pipeline API (the Python API will be adapted if the Java pipeline API changes).
- Support native Python Transformer/Estimator/Model, i.e., users can write not only Python wrappers that call Java Transformers/Estimators/Models but also native Python ones.
- Ease of use: support keyword arguments when defining parameters.
Proposed Changes
Java Pipeline API and the related Python API
To give the whole picture, the table below lists the key Java interfaces and the corresponding Python interfaces to be added.
Interface type | Java Interface Name | Python Interface Name | Description |
ML core interface | PipelineStage | PipelineStage | The base node of Pipeline |
ML core interface | Transformer | Transformer | Native Python interface |
ML core interface | Transformer | JavaTransformer | Python wrappers for calling Java interface |
ML core interface | Estimator | Estimator | Native Python interface |
ML core interface | Estimator | JavaEstimator | Python wrappers for calling Java interface |
ML core interface | Model | Model | Native Python interface |
ML core interface | Model | JavaModel | Python wrappers for calling Java interface |
ML Pipeline | Pipeline | Pipeline | Describes a ML workflow |
ML environment | MLEnvironment | MLEnvironment | Stores the necessary context in Flink |
ML environment | MLEnvironmentFactory | MLEnvironmentFactory | Factory to get the MLEnvironment |
Help interface | Params | Params | A container of parameters |
Help interface | ParamInfo | ParamInfo | Definition of a parameter |
Help interface | WithParams | WithParams | Common interface to interact with classes with parameters |
Support native Python Transformer/Estimator/Model
Transformers, Estimators and Models can each be treated as an algorithm. By combining different Transformers, Estimators and Models, users can write different machine learning jobs.
There are two ways to provide Python Transformers, Estimators and Models. On one hand, we can provide Python Transformers/Estimators/Models that wrap existing Java ones to leverage the power of the Java APIs. On the other hand, there are scenarios in which Python users want to write a native Python Transformer/Estimator/Model, and this should also be supported.
To support both cases, we provide two kinds of interfaces for Transformer/Estimator/Model: one for Java wrappers, the other for native Python implementations.
Java Interface Name | Python Interface Name |
Transformer | Transformer |
Transformer | JavaTransformer |
Estimator | Estimator |
Estimator | JavaEstimator |
Model | Model |
Model | JavaModel |
Below, we will take Transformer as an example to show you what the interfaces would be like and how to implement these two kinds of interfaces.
class Transformer(PipelineStage):
    """
    A transformer is a PipelineStage that transforms an input Table to a result Table.
    """

    __metaclass__ = ABCMeta

    @abstractmethod
    def transform(self, table_env, table):
        raise NotImplementedError()


class JavaTransformer(Transformer):
    """
    Base class for :py:class:`Transformer`s that wrap Java implementations.
    Subclasses should ensure they have the transformer Java object available as j_obj.
    """

    def __init__(self, j_obj):
        super().__init__()
        self._j_obj = j_obj

    def transform(self, table_env, table):
        self._convert_params_to_java(self._j_obj)
        return Table(self._j_obj.transform(table_env._j_tenv, table._j_table))
To write a native Python Transformer, extend the Transformer class and implement the transform() method. For a Java wrapper, extend the JavaTransformer class and pass the Java object to the constructor. In this case, you don't have to implement the transform() method because JavaTransformer already implements it.
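As a rough illustration of the native path, the sketch below defines a hypothetical `AddConstantTransformer`. To stay runnable without a Flink installation it relies on simplified stand-ins: `PipelineStage` is an empty class, and the "table" is a plain list of dicts rather than a pyflink `Table`. These stand-ins and the transformer itself are assumptions made for illustration and are not part of the proposed API.

```python
from abc import ABC, abstractmethod


class PipelineStage(object):
    """Stand-in for the proposed PipelineStage (assumption: params omitted)."""


class Transformer(PipelineStage, ABC):
    """Stand-in mirroring the proposed Transformer interface."""

    @abstractmethod
    def transform(self, table_env, table):
        raise NotImplementedError()


class AddConstantTransformer(Transformer):
    """Hypothetical native Python transformer: adds a constant to one column."""

    def __init__(self, input_col, output_col, constant):
        self.input_col = input_col
        self.output_col = output_col
        self.constant = constant

    def transform(self, table_env, table):
        # A real implementation would call Table API operations on `table`;
        # here the "table" is a list of dicts so the sketch runs anywhere.
        return [
            dict(row, **{self.output_col: row[self.input_col] + self.constant})
            for row in table
        ]


rows = [{"a": 1}, {"a": 2}]
result = AddConstantTransformer("a", "b", 10).transform(None, rows)
```

The only method a native transformer has to supply is `transform()`; everything else in the class is ordinary Python.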
Support keyword arguments
Parameters are widely used in machine learning algorithms. In the Java APIs, the builder pattern is used to set params:
// Java
KMeans kMeans = new KMeans()
.setVectorCol("features")
.setK(2)
.setPredictionCol("prediction_result");
To align the APIs, the builder pattern will also be supported in Python.
# Python, builder pattern
kmeans = KMeans() \
.set_vector_col("features") \
.set_k(2) \
.set_prediction_col("prediction_result")
However, the builder pattern is not commonly used in Python because the language supports keyword arguments. To make the API convenient for Python users, we propose to also support keyword arguments in the constructor:
# Python, keyword arguments
kmeans = KMeans(vector_col="features", k=2, prediction_col="prediction_result")
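One way the two styles could share a single code path is sketched below: both the constructor and the builder-style setters forward to a `_set(**kwargs)` helper that looks up a `ParamInfo` by attribute name. The class `KMeansSketch`, its getters, and the trimmed `ParamInfo`/`Params` classes are hypothetical stand-ins written for this example, not the actual KMeans wrapper.

```python
class ParamInfo(object):
    """Trimmed stand-in: a parameter name, description and converter."""

    def __init__(self, name, description, type_converter=None):
        self.name = name
        self.description = description
        self.type_converter = type_converter or (lambda v: v)


class Params(object):
    """Trimmed stand-in: a map from ParamInfo to value."""

    def __init__(self):
        self._param_map = {}

    def set(self, k, v):
        self._param_map[k] = v

    def get(self, k):
        return self._param_map[k]


class KMeansSketch(object):
    """Hypothetical estimator-like class used only to show parameter handling."""

    k = ParamInfo("k", "number of clusters", int)
    vector_col = ParamInfo("vector_col", "name of the feature column", str)

    def __init__(self, **kwargs):
        self._params = Params()
        self._set(**kwargs)  # keyword-argument style

    def _set(self, **kwargs):
        for name, value in kwargs.items():
            info = getattr(self, name)  # class-level ParamInfo with that name
            if value is not None:
                value = info.type_converter(value)
            self._params.set(info, value)
        return self

    # Builder-style setters return self so that calls can be chained.
    def set_k(self, k):
        return self._set(k=k)

    def set_vector_col(self, col):
        return self._set(vector_col=col)

    def get_k(self):
        return self._params.get(KMeansSketch.k)

    def get_vector_col(self):
        return self._params.get(KMeansSketch.vector_col)


by_kwargs = KMeansSketch(vector_col="features", k=2)
by_builder = KMeansSketch().set_vector_col("features").set_k(2)
```

Under these assumptions both objects end up holding the same parameter values, so users can pick whichever style reads better.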
Modules
Two more Python packages will be added under the current pyflink package: `ml` and `mllib`. Package `ml` holds classes that align with the Java flink-ml-api module, and package `mllib` holds classes that align with the Java flink-ml-lib module.
flink-python(maven module)
- pyflink(python package)
- ml
- mllib
As for components in the Flink JIRA, we can reuse the current `API/Python` component tag.
Public interfaces
The detailed design of all Python interfaces is listed below.
ML core interface
PipelineStage
class PipelineStage(WithParams):
    """
    Base class for a stage in a pipeline.
    """

    def __init__(self, params=None):
        if params is None:
            self._params = Params()
        else:
            self._params = params

    def get_params(self):
        return self._params
Transformer
class Transformer(PipelineStage):
    """
    A transformer is a PipelineStage that transforms an input Table to a result Table.
    """

    __metaclass__ = ABCMeta

    @abstractmethod
    def transform(self, table_env, table):
        """
        Applies the transformer on the input table, and returns the result table.

        :param table_env: the table environment to which the input table is bound.
        :param table: the table to be transformed
        :returns: the transformed table
        """
        raise NotImplementedError()
JavaTransformer
class JavaTransformer(Transformer):
    """
    Base class for :py:class:`Transformer`s that wrap Java implementations.
    Subclasses should ensure they have the transformer Java object available as j_obj.
    """

    def __init__(self, j_obj):
        super().__init__()
        self._j_obj = j_obj

    def transform(self, table_env, table):
        """
        Applies the transformer on the input table, and returns the result table.

        :param table_env: the table environment to which the input table is bound.
        :param table: the table to be transformed
        :returns: the transformed table
        """
        self._convert_params_to_java(self._j_obj)
        return Table(self._j_obj.transform(table_env._j_tenv, table._j_table))
Estimator
class Estimator(PipelineStage):
    """
    Estimators are PipelineStages responsible for training and generating machine learning models.

    The implementations are expected to take an input table as training samples and generate a
    Model which fits these samples.
    """

    __metaclass__ = ABCMeta

    def fit(self, table_env, table):
        """
        Train and produce a Model which fits the records in the given Table.

        :param table_env: the table environment to which the input table is bound.
        :param table: the table with records to train the Model.
        :returns: a model trained to fit on the given Table.
        """
        raise NotImplementedError()
JavaEstimator
class JavaEstimator(Estimator):
    """
    Base class for :py:class:`Estimator`s that wrap Java implementations.
    Subclasses should ensure they have the estimator Java object available as j_obj.
    """

    def __init__(self, j_obj):
        super().__init__()
        self._j_obj = j_obj

    def fit(self, table_env, table):
        """
        Train and produce a Model which fits the records in the given Table.

        :param table_env: the table environment to which the input table is bound.
        :param table: the table with records to train the Model.
        :returns: a model trained to fit on the given Table.
        """
        self._convert_params_to_java(self._j_obj)
        return JavaModel(self._j_obj.fit(table_env._j_tenv, table._j_table))
Model
class Model(Transformer):
    """
    Abstract class for models that are fitted by estimators.

    A model is an ordinary Transformer except how it is created. While ordinary transformers
    are defined by specifying the parameters directly, a model is usually generated by an
    Estimator when Estimator.fit(table_env, table) is invoked.
    """

    __metaclass__ = ABCMeta
JavaModel
class JavaModel(JavaTransformer, Model):
    """
    Base class for :py:class:`Model`s that wrap Java implementations.
    Subclasses should ensure they have the model Java object available as j_obj.
    """

    def __init__(self, j_obj):
        super().__init__(j_obj)
ML Pipeline
class Pipeline(Estimator, Model):
    """
    A pipeline is a linear workflow which chains Estimators and Transformers to execute an algorithm.
    """

    def __init__(self, stages=None):
        super().__init__()
        self.stages = []
        self.last_estimator_index = -1
        if stages is not None:
            # Route through append_stage so that last_estimator_index stays
            # consistent with any Estimators passed to the constructor.
            for stage in stages:
                self.append_stage(stage)

    def _need_fit(self):
        return self.last_estimator_index >= 0

    @staticmethod
    def _is_stage_need_fit(stage):
        return (isinstance(stage, Pipeline) and stage._need_fit()) or \
               ((not isinstance(stage, Pipeline)) and isinstance(stage, Estimator))

    def append_stage(self, stage):
        if self._is_stage_need_fit(stage):
            self.last_estimator_index = len(self.stages)
        elif not isinstance(stage, Transformer):
            raise RuntimeError("All PipelineStages should be Estimator or Transformer!")
        self.stages.append(stage)
        return self

    def fit(self, t_env, input):
        """
        Train the pipeline to fit on the records in the given Table.

        :param t_env: the table environment to which the input table is bound.
        :param input: the table with records to train the Pipeline.
        :returns: a pipeline with same stages as this Pipeline except all Estimators
                  replaced with their corresponding Models.
        """
        transform_stages = []
        for i in range(0, len(self.stages)):
            s = self.stages[i]
            if i <= self.last_estimator_index:
                need_fit = self._is_stage_need_fit(s)
                if need_fit:
                    t = s.fit(t_env, input)
                else:
                    t = s
                transform_stages.append(t)
                input = t.transform(t_env, input)
            else:
                transform_stages.append(s)
        return Pipeline(transform_stages)

    def transform(self, t_env, input):
        """
        Generate a result table by applying all the stages in this pipeline to the
        input table in order.

        :param t_env: the table environment to which the input table is bound.
        :param input: the table to be transformed.
        :returns: a result table with all the stages applied to the input tables in order.
        """
        if self._need_fit():
            raise RuntimeError("Pipeline contains Estimator, need to fit first.")
        for s in self.stages:
            input = s.transform(t_env, input)
        return input
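The fit() logic above can be hard to follow at a glance, so here is a simplified, self-contained sketch of the same idea for a flat list of stages: every stage up to the last Estimator is fitted (if needed) and applied to the training data, while later stages are carried over untouched. The stage classes `Scale`, `Center` and `CenterModel`, the helper functions, and the use of plain Python lists as "tables" are all assumptions made for illustration; nested pipelines are not handled.

```python
class Transformer(object):
    def transform(self, t_env, table):
        raise NotImplementedError()


class Estimator(object):
    def fit(self, t_env, table):
        raise NotImplementedError()


class Scale(Transformer):
    """Hypothetical transformer: multiplies every value by a factor."""

    def __init__(self, factor):
        self.factor = factor

    def transform(self, t_env, table):
        return [v * self.factor for v in table]


class CenterModel(Transformer):
    """Hypothetical model: subtracts the mean learned during fit."""

    def __init__(self, mean):
        self.mean = mean

    def transform(self, t_env, table):
        return [v - self.mean for v in table]


class Center(Estimator):
    """Hypothetical estimator: learns the mean of the training values."""

    def fit(self, t_env, table):
        return CenterModel(sum(table) / len(table))


def fit_stages(stages, t_env, table):
    """Simplified counterpart of Pipeline.fit for a flat list of stages."""
    last_estimator_index = max(
        (i for i, s in enumerate(stages) if isinstance(s, Estimator)), default=-1)
    fitted = []
    for i, stage in enumerate(stages):
        if i <= last_estimator_index:
            t = stage.fit(t_env, table) if isinstance(stage, Estimator) else stage
            fitted.append(t)
            # Feed the transformed data to the next stage's fit.
            table = t.transform(t_env, table)
        else:
            # Stages after the last Estimator never see the training data.
            fitted.append(stage)
    return fitted


def transform_stages(stages, t_env, table):
    """Simplified counterpart of Pipeline.transform."""
    for s in stages:
        table = s.transform(t_env, table)
    return table


train = [1.0, 2.0, 3.0]
fitted = fit_stages([Scale(2.0), Center(), Scale(10.0)], None, train)
result = transform_stages(fitted, None, [4.0])
```

In this sketch `Center` is trained on the output of `Scale(2.0)`, so the learned mean is 4.0 rather than 2.0, which is the reason fit() has to transform the data as it walks through the stages.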
ML environment
MLEnvironmentFactory
class MLEnvironmentFactory:
    """
    Factory to get the MLEnvironment using a MLEnvironmentId.
    """

    _lock = threading.RLock()
    _default_ml_environment_id = 0
    _next_id = 1
    _map = {}

    gateway = get_gateway()
    j_ml_env = gateway.jvm.MLEnvironmentFactory.getDefault()
    _default_ml_env = MLEnvironment(
        ExecutionEnvironment(j_ml_env.getExecutionEnvironment()),
        StreamExecutionEnvironment(j_ml_env.getStreamExecutionEnvironment()),
        BatchTableEnvironment(j_ml_env.getBatchTableEnvironment()),
        StreamTableEnvironment(j_ml_env.getStreamTableEnvironment()))
    _map[_default_ml_environment_id] = _default_ml_env

    @staticmethod
    def get(ml_env_id):
        """
        Get the MLEnvironment using a MLEnvironmentId.

        :param ml_env_id: the MLEnvironmentId
        :return: the MLEnvironment
        """
        with MLEnvironmentFactory._lock:
            if ml_env_id not in MLEnvironmentFactory._map:
                raise ValueError(
                    "Cannot find MLEnvironment for MLEnvironmentId %s. "
                    "Did you get the MLEnvironmentId by calling "
                    "get_new_ml_environment_id?" % ml_env_id)
            return MLEnvironmentFactory._map[ml_env_id]

    @staticmethod
    def get_default():
        """
        Get the MLEnvironment use the default MLEnvironmentId.

        :return: the default MLEnvironment.
        """
        with MLEnvironmentFactory._lock:
            return MLEnvironmentFactory._map[MLEnvironmentFactory._default_ml_environment_id]

    @staticmethod
    def get_new_ml_environment_id():
        """
        Create a unique MLEnvironment id and register a new MLEnvironment in the factory.

        :return: the MLEnvironment id.
        """
        with MLEnvironmentFactory._lock:
            return MLEnvironmentFactory.register_ml_environment(MLEnvironment())

    @staticmethod
    def register_ml_environment(ml_environment):
        """
        Register a new MLEnvironment to the factory and return a new MLEnvironment id.

        :param ml_environment: the MLEnvironment that will be stored in the factory.
        :return: the MLEnvironment id.
        """
        with MLEnvironmentFactory._lock:
            MLEnvironmentFactory._map[MLEnvironmentFactory._next_id] = ml_environment
            MLEnvironmentFactory._next_id += 1
            return MLEnvironmentFactory._next_id - 1

    @staticmethod
    def remove(ml_env_id):
        """
        Remove the MLEnvironment using the MLEnvironmentId.

        :param ml_env_id: the id.
        :return: the removed MLEnvironment
        """
        with MLEnvironmentFactory._lock:
            if ml_env_id is None:
                raise ValueError("The environment id cannot be null.")
            # Never remove the default MLEnvironment. Just return the default environment.
            # Compare against the default id (not the default environment object).
            if MLEnvironmentFactory._default_ml_environment_id == ml_env_id:
                return MLEnvironmentFactory.get_default()
            else:
                return MLEnvironmentFactory._map.pop(ml_env_id)
MLEnvironment
class MLEnvironment(object):
    """
    The MLEnvironment stores the necessary context in Flink. Each MLEnvironment
    will be associated with a unique ID. The operations associated with the same
    MLEnvironment ID will share the same Flink job context. Both MLEnvironment
    ID and MLEnvironment can only be retrieved from MLEnvironmentFactory.
    """

    def __init__(self, exe_env=None, stream_exe_env=None, batch_tab_env=None, stream_tab_env=None):
        self._exe_env = exe_env
        self._stream_exe_env = stream_exe_env
        self._batch_tab_env = batch_tab_env
        self._stream_tab_env = stream_tab_env

    def get_execution_environment(self):
        if self._exe_env is None:
            self._exe_env = ExecutionEnvironment.get_execution_environment()
        return self._exe_env

    def get_stream_execution_environment(self):
        if self._stream_exe_env is None:
            self._stream_exe_env = StreamExecutionEnvironment.get_execution_environment()
        return self._stream_exe_env

    def get_batch_table_environment(self):
        if self._batch_tab_env is None:
            self._batch_tab_env = BatchTableEnvironment.create(
                ExecutionEnvironment.get_execution_environment())
        return self._batch_tab_env

    def get_stream_table_environment(self):
        if self._stream_tab_env is None:
            self._stream_tab_env = StreamTableEnvironment.create(
                StreamExecutionEnvironment.get_execution_environment())
        return self._stream_tab_env
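To make the id lifecycle concrete, the following sketch walks through register, get and remove against a trimmed-down registry. It assumes an empty stand-in `MLEnvironment` with no Flink context and omits the Java gateway entirely, so it only illustrates the id bookkeeping and the special handling of the default id, not the proposed implementation itself.

```python
import threading


class MLEnvironment(object):
    """Stand-in for the proposed MLEnvironment (assumption: no Flink context)."""


class MLEnvironmentFactory(object):
    """Trimmed-down registry mirroring the proposed factory's id handling."""

    _lock = threading.RLock()
    _default_ml_environment_id = 0
    _next_id = 1
    _map = {_default_ml_environment_id: MLEnvironment()}

    @staticmethod
    def get(ml_env_id):
        with MLEnvironmentFactory._lock:
            if ml_env_id not in MLEnvironmentFactory._map:
                raise ValueError("Cannot find MLEnvironment for id %s." % ml_env_id)
            return MLEnvironmentFactory._map[ml_env_id]

    @staticmethod
    def get_default():
        return MLEnvironmentFactory.get(MLEnvironmentFactory._default_ml_environment_id)

    @staticmethod
    def register_ml_environment(ml_environment):
        with MLEnvironmentFactory._lock:
            new_id = MLEnvironmentFactory._next_id
            MLEnvironmentFactory._map[new_id] = ml_environment
            MLEnvironmentFactory._next_id += 1
            return new_id

    @staticmethod
    def remove(ml_env_id):
        with MLEnvironmentFactory._lock:
            # The default environment is never removed; it is simply returned.
            if ml_env_id == MLEnvironmentFactory._default_ml_environment_id:
                return MLEnvironmentFactory.get_default()
            return MLEnvironmentFactory._map.pop(ml_env_id)


env_id = MLEnvironmentFactory.register_ml_environment(MLEnvironment())
env = MLEnvironmentFactory.get(env_id)
removed = MLEnvironmentFactory.remove(env_id)
# Removing the default id leaves the default environment registered.
default_env = MLEnvironmentFactory.remove(0)
```

After `remove(env_id)` a further `get(env_id)` raises `ValueError`, whereas the default environment stays available for the lifetime of the process.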
Params interface
Params
class Params(object):
    """
    The map-like container class for parameter. This class is provided to unify
    the interaction with parameters.
    """

    def __init__(self):
        self._paramMap = {}

    def set(self, k, v):
        self._paramMap[k] = v

    def get(self, k):
        return self._paramMap[k]
ParamInfo
class ParamInfo(object):
    """
    Definition of a parameter, including name, description, type_converter and so on.
    """

    def __init__(self, name, description, type_converter=None):
        self.name = str(name)
        self.description = str(description)
        self.type_converter = TypeConverters.identity if type_converter is None else type_converter
WithParams
class WithParams(object):
    """
    Parameters are widely used in machine learning realm. This class defines a
    common interface to interact with classes with parameters.
    """

    def get_params(self):
        pass

    def set(self, k, v):
        self.get_params().set(k, v)
        return self

    def get(self, k):
        return self.get_params().get(k)

    def _set(self, **kwargs):
        """
        Sets user-supplied params.
        """
        for param, value in kwargs.items():
            p = getattr(self, param)
            if value is not None:
                try:
                    value = p.type_converter(value)
                except TypeError as e:
                    raise TypeError('Invalid param value given for param "%s". %s' % (p.name, e))
            self.get_params().set(p, value)
        return self
Example
trainTable = t_env.from_path('traningSource')
servingTable = t_env.from_path('servingSource')

# transformer
va = VectorAssembler(selected_cols=["a", "b"], output_col="features")

# estimator
kmeans = KMeans()\
    .set_vector_col("features")\
    .set_k(2)\
    .set_reserved_cols(["a", "b"])\
    .set_prediction_col("prediction_result")\
    .set_max_iter(100)

# pipeline
pipeline = Pipeline().append_stage(va).append_stage(kmeans)
pipeline\
    .fit(t_env, trainTable)\
    .transform(t_env, servingTable)\
    .insert_into('mySink')

t_env.execute('KmeansTest')
Implementation Plan
- Align interface for MLEnvironment and MLEnvironmentFactory
- Add support for Python Transformer/Estimator/Model.
- Add support for Transformer/Estimator/Model Java wrappers.