参考 GitHub 上 jpmml-lightgbm 项目的近期方案,实现加载 LightGBM 导出的 txt 格式模型文件,并将其转换为 PMML 后构建模型执行器进行预测。
添加 Maven 依赖:
<dependency><groupId>org.jpmml</groupId><artifactId>pmml-lightgbm</artifactId><version>1.5.4</version>
</dependency>
<dependency><groupId>org.jpmml</groupId><artifactId>pmml-evaluator</artifactId><version>1.6.6</version>
</dependency>
<dependency><groupId>org.jpmml</groupId><artifactId>pmml-model</artifactId><version>1.6.6</version>
</dependency>
工具类
import lombok.extern.slf4j.Slf4j;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelEvaluatorBuilder;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.evaluator.ProbabilityDistribution;
import org.jpmml.evaluator.TargetField;
import org.jpmml.lightgbm.GBDT;
import org.jpmml.lightgbm.HasLightGBMOptions;
import org.jpmml.lightgbm.LightGBMUtil;
import org.jpmml.model.metro.MetroJAXBUtil;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;import java.io.File;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;/*** 加载、初始化 PMML模型文件 :* 依赖 pmml-lightgbm-1.5.0(AGPL-3.0 License)* <p>* 解析PMML文件 @link https://github.com/jpmml/jpmml-lightgbm* 生成evaluator @link https://github.com/jpmml/jpmml-evaluator*/
@Slf4j
public class LightgbmTxtInitializer {// description = "Custom objective function"private static String objectiveFunction = null;// description = "Transform LightGBM-style trees to PMML-style trees",private static boolean compact = true;// description = "Treat Not-a-Number (NaN) values as missing values",private static boolean nanAsMissing = true;// description = "Limit the number of trees. Defaults to all trees"private static Integer numIteration = null;// description = "Target name. Defaults to \"_target\""private static String targetName = null;// description = "Target categories. Defaults to 0-based index [0, 1, .., num_class - 1]"private static List<String> targetCategories = null;public static void main(String[] output) throws Exception {Resource resource = new ClassPathResource("lightgbm_model.txt");InputStream pmmlFileInputStream = resource.getInputStream();// 生成模型执行器ModelEvaluator evaluator = initEvaluator(pmmlFileInputStream);// 打印特征参数List<InputField> inputFields = evaluator.getInputFields();log.info("ModelEvaluator featureNames:" + inputFields);// 调试执行预测Map<String, Number> waitPreSample = new HashMap<>(8);waitPreSample.put("0", 0.1);waitPreSample.put("1", 0.2);waitPreSample.put("2", 0.3);String predictedValue = getPredictedValue(waitPreSample, evaluator);pmmlFileInputStream.close();}public static ModelEvaluator initEvaluator(InputStream pmmlFileInputStream) throws Exception {GBDT gbdt;long begin = System.currentTimeMillis();gbdt = LightGBMUtil.loadGBDT(pmmlFileInputStream);log.info("Loaded GBDT in {} ms.", (System.currentTimeMillis() - begin));if (objectiveFunction != null) {log.info("Setting custom objective function");gbdt.setObjectiveFunction(LightGBMUtil.parseObjectiveFunction(objectiveFunction));}Map<String, Object> options = new LinkedHashMap<>();options.put(HasLightGBMOptions.OPTION_COMPACT, compact);options.put(HasLightGBMOptions.OPTION_NAN_AS_MISSING, nanAsMissing);options.put(HasLightGBMOptions.OPTION_NUM_ITERATION, numIteration);// 
生成标准PMMLbegin = System.currentTimeMillis();PMML pmml;pmml = gbdt.encodePMML(options, targetName, targetCategories);long end = System.currentTimeMillis();log.info("Converted GBDT to PMML in {} ms.", (System.currentTimeMillis() - begin));// no need// 输出PMML格式文件begin = System.currentTimeMillis();File outputFile = new File("E://t.pmml");OutputStream os = new FileOutputStream(outputFile);MetroJAXBUtil.marshalPMML(pmml, os);log.info("Marshalled PMML in {} ms.", (System.currentTimeMillis() - begin));// 生成evaluatorbegin = System.currentTimeMillis();ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();ModelEvaluatorBuilder modelEvaluatorBuilder = new ModelEvaluatorBuilder(pmml);modelEvaluatorBuilder.setModelEvaluatorFactory(modelEvaluatorFactory);ModelEvaluator<?> evaluator = modelEvaluatorBuilder.build();evaluator.verify();log.info("Init evaluator in {} ms.", (System.currentTimeMillis() - begin));return evaluator;}public static String getPredictedValue(Map<String, ?> argumentMap,ModelEvaluator<?> evaluator) {// 预测计算Map<String, ?> evaluateResult = evaluator.evaluate(argumentMap);log.info("evaluateResult:" + evaluateResult);// 提取预测结果String predictedValue = null;TargetField targetFieldName = evaluator.getTargetField();Object targetFieldValue = evaluateResult.get(targetFieldName.getFieldName());// 输出预测结果if (targetFieldValue instanceof ProbabilityDistribution) {predictedValue = ((ProbabilityDistribution<?>) targetFieldValue).getPrediction().toString();log.info("Predicted value(ProbabilityDistribution) : " + predictedValue);} else if (targetFieldValue instanceof FieldValue) {FieldValue fieldValue = (FieldValue) targetFieldValue;predictedValue = fieldValue.asString();log.info("Predicted value(FieldValue) : " + predictedValue);} else if (targetFieldValue instanceof List) {List<String> resultList =((List<?>) targetFieldValue).stream().map(e -> ((FieldValue) e).asString()).collect(Collectors.toList());predictedValue = String.join(",", 
resultList);log.info("Predicted value(List) : " + predictedValue);} else {log.error("unknown type for targetFieldValue:" + targetFieldValue);}return predictedValue;}
}