一、DF转换器
- Transformer:SparkML中有很多直接对DF进行变换的类,如TF-IDF,PCA等,它们统称Transformer;
子类主要需要实现_transform方法。
from pyspark.ml import Estimator, Model
class PytorchClassifyModel(Model):
def _transform(self, dataset: DataFrame):
# ...
return spark.createDataFrame(rdd)
- Estimator:还有很多需要训练后才能对DF进行变换的类,如LR,GBDT以及NaiveBayes,它们统称Estimator,训练后产出Model,Model也是一种Transformer,因为它也能直接对DF进行变换;
子类主要需要实现_fit方法,并返回一个Model对象。
class PytorchLocalGPUClassifier(Estimator):
def _fit(self, dataset) -> PytorchClassifyModel:
# ...
return PytorchClassifyModel(model, input_cols, batch_size)
- Pipeline:因为Transformer对DF变换后仍产出DF,于是串联多个Transformer可以对DF进行流式处理。Pipeline就是专门处理流式DF的类。它可以串联多个Estimator和Transformer,经过训练后,产出PipelineModel,就是一连串Model和Transformer。因此,Pipeline继承自Estimator,PipelineModel继承自Model。
继承了Transformer和Estimator后,自然可以用在Pipeline和PipelineModel中。
二、模型参数
- Param:模型Model或者Transformer和Estimator都有一些参数可供调整,如输入列名、输出列名、PCA的主成分数量、LR的迭代次数等。SparkML中的Param只是变量的一种名字,包含了变量名称、说明和转换方式,不包含具体变量的内容。
- Params:是变量的容器;一堆参数才是所有模型的共性,因此,Transformer和Estimator都继承自Params。
Params:
_paramMap={}
_defaultParamMap={}
_set(**kwargs)
_setDefault(**kwargs)
getOrDefault(param)
Foo(Params):
inputCols = Param(Params._dummy(), "inputCols", "input cols", typeConverter=TypeConverters.toListString)
batchSize = Param(Params._dummy(), "batchSize", "batch size", typeConverter=TypeConverters.toInt)
inputColsType = Param(Params._dummy(), "inputColsType", "types of input cols")
def __init__(self, batch_size: int = 100):
self._setDefault(inputCols=['data'], inputColsType=input_cols_type)
self._set(batchSize=batch_size)
def setInputCols(self, value):
self._set(inputCols=value)
三、模型的存取
from pyspark.ml.util import MLReadable, MLWritable, DefaultParamsWriter, DefaultParamsReader
class PytorchClassifyModel(Model, MLWritable, MLReadable):
def write(self):
'''MLWritable的方法,返回一个有save方法的类,被称为Writer'''
return self
def save(self, path):
'''实际存储代码'''
DefaultParamsWriter(self).save(path)
sc = SparkSession.builder.getOrCreate().sparkContext
buffer = io.BytesIO()
torch.save(self.model, buffer)
sc.parallelize([buffer.getvalue()], 1).saveAsPickleFile(f'{path}/model.pk')
@classmethod
def read(cls):
'''MLReadable的方法,返回一个有load方法的类,被称为Reader'''
return cls
@classmethod
def load(cls, path):
'''实际读取代码'''
m: PytorchClassifyModel = DefaultParamsReader(cls).load(path)
sc = SparkSession.builder.getOrCreate().sparkContext
model_pk_str = sc.pickleFile(f'{path}/model.pk', 1).collect()[0]
buffer = io.BytesIO(model_pk_str)
m.model = torch.load(buffer)
return m