Se puede utilizar una red SOM que haga la partición de datos. Parte de lo que se ve en la práctica de ese tipo de redes y construye una que tenga tantas salidas, tantas clases, como modelos conformen tu equipo. Una vez que ya tengas la SOM, pasas por esa red el conjunto de datos que vas a utilizar para ajuste, y según a qué grupo quede asignado cada punto, se lo encargas a uno u otro modelo.

salred=red.map_vects(tea)
indcada=[list(filter(lambda ind: salred[ind][0]==este, range(len(salred)))) for este in range(numodelos)]
for mod in range(numodelos):
	dataj=tea[indcada[mod]].clone().detach().requires_grad_()
	salaj=tsa[indcada[mod]].clone().detach()

Para la prueba, pasas el conjunto de prueba por la red de agrupamientos, y según en qué grupo quede cada punto, se encarga de la predicción uno u otro modelo.

with torch.no_grad():
	ertot=0
	for punto,salreal in zip(tep,tsp):
		nm=red.map_vects(punto)[0][0]
		salida=modelos[nm](punto)
		ep=error(salida,salreal)
		ertot+=ep.item()
	print('RMSE prueba: ',math.sqrt(ertot/len(tep)))