Motivation
FLIP-437: Support ML Models in Flink SQL introduced the ML_PREDICT and ML_EVALUATE functions. This FLIP adds details on how they can be implemented in both SQL and the Table API, as well as the new interfaces to be introduced.
FLIP-437 proposed syntax:
SELECT f1, f2, label FROM ML_PREDICT(`my_data`, `classifier_model`, DESCRIPTOR(f1, f2))
SELECT * FROM ML_EVALUATE(`eval_data`, `classifier_model`, DESCRIPTOR(f1, f2))
After researching the code, we decided to keep the TABLE and MODEL keywords to indicate table and model identifiers. The syntax we will implement is:
Prediction
SELECT f1, f2, label FROM ML_PREDICT(TABLE `my_data`, MODEL `classifier_model`, DESCRIPTOR(f1, f2))
Evaluation
For the ML_EVALUATE function, this proposal separates the label and feature descriptors to make the call clearer.
SELECT * FROM ML_EVALUATE(TABLE `eval_data`, MODEL `classifier_model`, DESCRIPTOR(label), DESCRIPTOR(f1, f2))
SELECT * FROM ML_EVALUATE(input => TABLE `eval_data`, model_input => MODEL `classifier_model`, label => DESCRIPTOR(label), features => DESCRIPTOR(f1, f2), task => 'classification')
The task parameter can explicitly override the type of model to be evaluated.
Additionally, each predict and evaluate function can take an optional MAP runtime config, which configures how the function runs (for example, async or sync).
Public Interfaces
Model runtime
Below is a class diagram overview to introduce provider and runtime for model prediction:
ML_PREDICT function can have different implementations based on Model provider similar to Table connector. For example:
CREATE MODEL mymodel
INPUT (text STRING)
OUTPUT (result STRING)
WITH (
'provider' = 'openai',
'task' = 'text_generation',
'endpoint' = 'https://api.openai.com/v1/llm/v1/chat',
'apikey' = 'abcdefg',
'system_prompt' = 'translate text to english'
)
With the provider being openai, this model can use the runtime provided by OpenAI. The provider option is mandatory and is used to construct a ModelProviderFactory through Java SPI.
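To illustrate the identifier-based lookup, here is a minimal, self-contained sketch. Flink's real discovery goes through java.util.ServiceLoader and its factory utilities; the types and names below are simplified stand-ins, not the actual Flink API:

```java
import java.util.List;
import java.util.Map;

// Stand-in for the factory interface (hypothetical simplification).
interface ModelProviderFactory {
    String factoryIdentifier(); // matches the 'provider' option, e.g. "openai"
}

// Stand-in factory for provider = 'openai'.
class OpenAiProviderFactory implements ModelProviderFactory {
    public String factoryIdentifier() { return "openai"; }
}

public class FactoryDiscoverySketch {
    // In Flink, factories are discovered via java.util.ServiceLoader; here we
    // simulate the discovered set with a fixed list and pick by identifier.
    static ModelProviderFactory discover(String provider, List<ModelProviderFactory> discovered) {
        return discovered.stream()
                .filter(f -> f.factoryIdentifier().equals(provider))
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException(
                        "No factory found for provider '" + provider + "'"));
    }

    public static void main(String[] args) {
        Map<String, String> options = Map.of("provider", "openai", "task", "text_generation");
        ModelProviderFactory factory =
                discover(options.get("provider"), List.of(new OpenAiProviderFactory()));
        System.out.println(factory.factoryIdentifier()); // openai
    }
}
```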
ModelProviderFactory and ModelProvider are introduced as below to choose runtime based on provider:
/**
* Base interface for configuring a model provider.
*/
@PublicEvolving
public interface ModelProviderFactory extends Factory {
/**
 * Creates a {@link ModelProvider} based on the given context.
*/
ModelProvider createModelProvider(Context context);
/** Provides catalog and session information describing the model to be accessed. */
@PublicEvolving
interface Context {
/**
* Returns the identifier of the model in the {@link Catalog}.
*
* <p>This identifier describes the relationship between the model instance and the
* associated {@link Catalog} (if any).
*/
ObjectIdentifier getObjectIdentifier();
/**
* Returns the resolved model information received from the {@link Catalog} or persisted
* plan.
*
* <p>The {@link ResolvedCatalogModel} forwards the metadata from the catalog but offers a
* validated {@link ResolvedSchema}. The original metadata object is available via {@link
* ResolvedCatalogModel#getOrigin()}.
*
* <p>In most cases, a factory is interested in the following characteristics:
*
* <pre>{@code
* // get the physical input and output data type to initialize the connector
 * context.getCatalogModel().getResolvedInputSchema().toPhysicalRowDataType()
 * context.getCatalogModel().getResolvedOutputSchema().toPhysicalRowDataType()
*
* // get configuration options
* context.getCatalogModel().getOptions()
* }</pre>
*
* <p>During a plan restore, usually the model information persisted in the plan is used to
* reconstruct the catalog model.
*/
ResolvedCatalogModel getCatalogModel();
/** Gives read-only access to the configuration of the current session. */
ReadableConfig getConfiguration();
/**
* Returns the class loader of the current session.
*
* <p>The class loader is in particular useful for discovering further (nested) factories.
*/
ClassLoader getClassLoader();
/** Whether the model is temporary. */
boolean isTemporary();
}
}
/**
* Model Provider base interface.
*/
interface ModelProvider {
/**
* Creates a copy of this instance during planning. The copy should be a deep copy of all
* mutable members.
*/
ModelProvider copy();
/**
* Context for creating runtime providers.
*/
interface Context {
/** Resolved catalog model. */
ResolvedCatalogModel getCatalogModel();
/** Runtime config provided to provider */
ReadableConfig runtimeConfig();
}
}
/**
 * Provider of a synchronous model runtime for the predict function.
*/
interface PredictRuntimeProvider extends ModelProvider {
/**
 * Returns a synchronous ml_predict function.
*/
PredictFunction createPredictFunction(Context context);
}
/**
 * Provider of an asynchronous model runtime for the predict function.
*/
interface AsyncPredictRuntimeProvider extends ModelProvider {
/**
 * Returns an asynchronous ml_predict function.
*/
AsyncPredictFunction createAsyncPredictFunction(Context context);
}
/**
 * A wrapper class of {@link TableFunction} for synchronous prediction.
 *
 * <p>The output type of this table function is fixed as {@link RowData}.
*/
@PublicEvolving
public abstract class PredictFunction extends TableFunction<RowData> {
/**
* Synchronously predict result based on input row.
*
* @param inputRow - A {@link RowData} that wraps input for predict function.
*
* @return A collection of predicted results.
*/
public abstract Collection<RowData> predict(RowData inputRow);
/** Invoke {@link #predict} and handle exceptions. */
public final void eval(Object... args) {
try {
GenericRowData argsData = GenericRowData.of(args);
Collection<RowData> results = predict(argsData);
if (results == null) {
return;
}
results.forEach(this::collect);
} catch (Exception e) {
throw new FlinkRuntimeException("Prediction error", e);
}
}
}
/**
 * A wrapper class of {@link AsyncTableFunction} for asynchronous prediction.
*
* <p>The output type of this table function is fixed as {@link RowData}.
*/
@PublicEvolving
public abstract class AsyncPredictFunction extends AsyncTableFunction<RowData> {
/**
 * Asynchronously predicts result based on input row.
 *
 * @param inputRow - A {@link RowData} that wraps input for predict function.
 *
 * @return A future of a collection of predicted results.
 */
public abstract CompletableFuture<Collection<RowData>> asyncPredict(RowData inputRow);
/** Invoke {@link #asyncPredict} and handle exceptions. */
public final void eval(CompletableFuture<Collection<RowData>> completableFuture, Object... args) {
    try {
        GenericRowData argsData = GenericRowData.of(args);
        asyncPredict(argsData)
                .whenComplete(
                        (result, exception) -> {
                            if (exception != null) {
                                completableFuture.completeExceptionally(
                                        new TableException(
                                                "Failed to execute asynchronous prediction.",
                                                exception));
                                return;
                            }
                            completableFuture.complete(result);
                        });
    } catch (Exception e) {
        completableFuture.completeExceptionally(e);
    }
}
}
interface Module {
default Optional<ModelProviderFactory> getModelProviderFactory() {
return Optional.empty();
}
}
Async Predict Function Usage and Config
AsyncPredictFunction can be used if a ModelProvider implements AsyncPredictRuntimeProvider. To enable the async function, users need to pass 'async' = 'true' in the runtime config:
SELECT * FROM ML_PREDICT(TABLE input, MODEL mdl, DESCRIPTOR(f1, f2), MAP['async', 'true']);
If the runtime config map is not provided, the sync function is used by default.
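The sync-by-default selection can be sketched with stand-in types. The interfaces below are simplified stand-ins for the FLIP's provider interfaces, and the selection rule shown is one plausible reading of the behavior above, not the planner's actual implementation:

```java
import java.util.Map;

// Stand-ins for the FLIP's provider interfaces (simplified for illustration).
interface ModelProvider {}
interface PredictRuntimeProvider extends ModelProvider {}
interface AsyncPredictRuntimeProvider extends ModelProvider {}

// A provider that supports both sync and async runtimes.
class BothProvider implements PredictRuntimeProvider, AsyncPredictRuntimeProvider {}

public class RuntimeSelectionSketch {
    // Hypothetical selection rule: async is used only when the runtime config
    // requests it AND the provider actually implements the async interface.
    static boolean useAsync(ModelProvider provider, Map<String, String> runtimeConfig) {
        boolean asyncRequested =
                Boolean.parseBoolean(runtimeConfig.getOrDefault("async", "false"));
        return asyncRequested && provider instanceof AsyncPredictRuntimeProvider;
    }

    public static void main(String[] args) {
        ModelProvider provider = new BothProvider();
        System.out.println(useAsync(provider, Map.of("async", "true"))); // true
        System.out.println(useAsync(provider, Map.of()));                // false: sync by default
    }
}
```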
Additional configs for async function:
| Name | Meaning |
|---|---|
| max-concurrent-operations | The max number of async I/O operations that the async predict can trigger. |
| timeout | The total time which can pass before the invocation (including retries) is considered timed out and task execution fails. |
| retry-strategy | FIXED_DELAY retries after a fixed amount of time. |
| retry-delay | The time to wait between retries for the FIXED_DELAY strategy. Could be the base delay time for a (not yet proposed) exponential backoff. |
| max-attempts | The maximum number of attempts while retrying. |
| output-mode | Output mode for the async operation (e.g. ordered vs. allowing unordered results). |
| async | Whether to run the async predict function or not. Defaults to false. |
The defaults for these configs are:
table.exec.async-ml-predict.max-concurrent-operations: 10
table.exec.async-ml-predict.timeout: 30s
table.exec.async-ml-predict.retry-strategy: FIXED_DELAY
table.exec.async-ml-predict.fixed-delay: 10s
table.exec.async-ml-predict.max-attempts: 3
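The FIXED_DELAY retry behavior described above can be sketched as follows. This is a minimal, self-contained illustration using plain CompletableFuture; the names and the scheduling details are assumptions, not Flink's actual retry machinery:

```java
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class FixedDelayRetrySketch {
    static final ScheduledExecutorService SCHEDULER = Executors.newScheduledThreadPool(1);

    // Retries an async call up to maxAttempts with a fixed delay between
    // attempts, mirroring retry-strategy = FIXED_DELAY with retry-delay.
    static <T> CompletableFuture<T> withRetry(
            Callable<CompletableFuture<T>> call, int maxAttempts, long delayMillis) {
        CompletableFuture<T> result = new CompletableFuture<>();
        attempt(call, maxAttempts, delayMillis, result);
        return result;
    }

    private static <T> void attempt(
            Callable<CompletableFuture<T>> call, int attemptsLeft, long delayMillis,
            CompletableFuture<T> result) {
        try {
            call.call().whenComplete((value, error) -> {
                if (error == null) {
                    result.complete(value);
                } else if (attemptsLeft > 1) {
                    // Wait a fixed delay, then retry with one fewer attempt left.
                    SCHEDULER.schedule(
                            () -> attempt(call, attemptsLeft - 1, delayMillis, result),
                            delayMillis, TimeUnit.MILLISECONDS);
                } else {
                    result.completeExceptionally(error);
                }
            });
        } catch (Exception e) {
            result.completeExceptionally(e);
        }
    }

    public static void main(String[] args) {
        AtomicInteger calls = new AtomicInteger();
        // Fails twice, then succeeds on the third attempt (max-attempts = 3).
        String value = withRetry(() -> calls.incrementAndGet() < 3
                        ? CompletableFuture.failedFuture(new RuntimeException("transient"))
                        : CompletableFuture.completedFuture("ok"),
                3, 10).join();
        System.out.println(value + " after " + calls.get() + " attempts");
        SCHEDULER.shutdown();
    }
}
```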
ML_EVALUATE function
For ML_EVALUATE, we don't provide an interface for users to implement because users can easily use predicted values to calculate evaluation metrics themselves. Instead, Flink provides a default implementation of ML_EVALUATE based on the task option. The output type of the evaluation function will be a MAP, and the keys can be different metrics according to the task type.
Output metrics based on task type
| Task | Output metrics |
|---|---|
| classification | Accuracy, Precision, Recall, F1 Score |
| clustering | Mean Davies-Bouldin Index |
| embedding | Mean Cosine Similarity |
| regression | Mean Absolute Error, Mean Squared Error, Root Mean Squared Error, Mean Absolute Percentage Error |
| text_generation | BLEU Score, ROUGE Score, Semantic Similarity |
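As a rough sketch of what the built-in classification evaluation could compute, the following self-contained Java builds the metrics map for binary classification. The metric key names here are illustrative assumptions, not the final output keys:

```java
import java.util.List;
import java.util.Map;

public class ClassificationMetricsSketch {
    // Computes the binary-classification metrics that ML_EVALUATE would emit
    // for task = 'classification': accuracy, precision, recall, and F1 score.
    static Map<String, Double> evaluate(List<Integer> labels, List<Integer> predictions) {
        double tp = 0, fp = 0, fn = 0, correct = 0;
        for (int i = 0; i < labels.size(); i++) {
            int y = labels.get(i), p = predictions.get(i);
            if (y == p) correct++;
            if (p == 1 && y == 1) tp++;   // true positive
            if (p == 1 && y == 0) fp++;   // false positive
            if (p == 0 && y == 1) fn++;   // false negative
        }
        double accuracy = correct / labels.size();
        double precision = tp + fp == 0 ? 0 : tp / (tp + fp);
        double recall = tp + fn == 0 ? 0 : tp / (tp + fn);
        double f1 = precision + recall == 0 ? 0 : 2 * precision * recall / (precision + recall);
        return Map.of("ACCURACY", accuracy, "PRECISION", precision, "RECALL", recall, "F1", f1);
    }

    public static void main(String[] args) {
        Map<String, Double> m = evaluate(List.of(1, 0, 1, 1), List.of(1, 0, 0, 1));
        System.out.println(m.get("ACCURACY")); // 0.75
    }
}
```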
Model SQL Table Functions
SqlMlPredictTableFunction and SqlMlEvaluateTableFunction will be registered as built-in SqlFunctions which provide type inference and validation. We can create a base class SqlMlFunctionTableFunction for them to indicate ML table functions in the planning phase.
public abstract class SqlMlFunctionTableFunction extends SqlFunction implements SqlTableFunction {
@Override
public SqlReturnTypeInference getRowTypeInference() {
return this::inferRowType;
}
protected abstract RelDataType inferRowType(SqlOperatorBinding opBinding);
}
public class SqlMlPredictTableFunction extends SqlMlFunctionTableFunction {
// Input/output validation and type inference
}
public class SqlMlEvaluateTableFunction extends SqlMlFunctionTableFunction {
// Input/output validation and type inference
}
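As a rough illustration of the row type inference these functions perform, one plausible rule is that ML_PREDICT's result row type is the input table's columns followed by the model's declared output columns. The sketch below works on column-name lists for clarity; the real implementation operates on Calcite's RelDataType:

```java
import java.util.ArrayList;
import java.util.List;

public class RowTypeInferenceSketch {
    // Hypothetical inference rule: result row = input table columns followed
    // by the model's declared output columns.
    static List<String> inferPredictRowType(
            List<String> inputColumns, List<String> modelOutputColumns) {
        List<String> result = new ArrayList<>(inputColumns);
        result.addAll(modelOutputColumns);
        return result;
    }

    public static void main(String[] args) {
        // For SELECT f1, f2, label FROM ML_PREDICT(...) in the examples above.
        System.out.println(inferPredictRowType(List.of("f1", "f2"), List.of("label")));
    }
}
```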
Proposed Changes
Sql syntax and planner changes for sql execution
The MODEL keyword will be introduced in the SQL parser to parse a model identifier to SqlModelCall, which holds the model identifier. When validating the call node, we can get and store the model for type inference etc. This is slightly different from the FLIP, where there's no MODEL or TABLE keyword in the function parameters; without these keywords it's difficult to create model and table nodes from the parser.
We need to extend FlinkCalciteCatalogReader to pass in CatalogManager and provide a getModel function.
We need to translate SqlModelCall to a RexModelCall node in SqlToRelConverter, during which CatalogModel will be translated to a ModelProvider instance which contains ResolvedCatalogModel.
We need to introduce physical planner rules StreamPhysicalModelTableFunctionRule and BatchPhysicalModelTableFunctionRule to translate model table function calls to ExecNodes.
- During the translation, we will look at the functions (ML_PREDICT/ML_EVALUATE or some other function later on) and create the corresponding ExecNode.
- In the ExecNode, we need to create a FunctionProvider based on the Context (model options, runtime config such as whether it's async or not, etc.).
Compatibility, Deprecation, and Migration Plan
There’s no compatibility issue since these are new features.
If we plan to support ML_PREDICT/ML_EVALUATE as PTFs later, compiled plans won’t be affected as long as the ExecNodes are there. We can change the planner to stop supporting the old syntax (such as dropping the MODEL keyword and the model SQL and rex nodes).
Future work
More generic support of Model arguments in PTFs and the PTF Table API, so users can flexibly define how models are used in their own UDFs.
Test Plan
All existing tests pass
Add new unit tests and integration tests for any new code changes
Rejected Alternatives
None
