
Machine Learning Provider Abstraction Layer #70

Open
javadba opened this issue Feb 16, 2016 · 5 comments

@javadba
Contributor

javadba commented Feb 16, 2016

I am a coder on a team considering using SparkNet with an ML library other than Caffe. The intent of this issue is to capture discussion of an ML Provider Abstraction Layer (MLPAL?) that would permit pluggable use of Caffe vs. SomeOtherMLLibrary.

To the core committers: do you already have thoughts and/or a roadmap for this? In any case, our thoughts will start appearing here.

@javadba javadba changed the title Machine Learning Library Provider abstraction layer Machine Learning Provider Abstraction Layer Feb 16, 2016
@pcmoritz
Collaborator

Thanks for bringing this up; we are very interested in this question. Providing unified APIs and data loading procedures is one of the areas where we can add value compared to what is already out there in terms of deep learning libraries. Data loading/processing is one of Spark's main strengths.

Let us know about your suggestions; our current plan is to provide interfaces that can be implemented by various backends.

For the network, the interface would look like this:

trait NetInterface {
  // Run a forward pass over a partition of input rows; returns the outputs.
  def forward(rowIt: Iterator[Row]): Array[Row]
  // Run a forward and backward pass, computing gradients in the backend.
  def forwardBackward(rowIt: Iterator[Row])
  // Extract the current model weights from the backend.
  def getWeights(): WeightCollection
  // Load the given weights into the backend.
  def setWeights(weights: WeightCollection)
  // Schema of the rows produced by forward.
  def outputSchema(): StructType
}

For the Solver:

trait Solver {
  // Perform one optimization step over a partition of training rows.
  def step(rowIt: Iterator[Row])
}
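
To make the pluggability concrete, a third-party backend might implement these traits along the following lines (a rough sketch only; OtherMlNet and everything inside it are hypothetical, not SparkNet code):

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType

// Hypothetical backend stub; each method would delegate to the other ML library.
class OtherMlNet(modelDef: Array[Byte]) extends NetInterface {
  override def forward(rowIt: Iterator[Row]): Array[Row] = {
    // Convert rows to the backend's tensor format, run inference,
    // and convert the outputs back into rows.
    ???
  }
  override def forwardBackward(rowIt: Iterator[Row]): Unit = {
    // Run forward + backward so the backend accumulates gradients.
    ???
  }
  override def getWeights(): WeightCollection = ??? // read weights out of the backend
  override def setWeights(weights: WeightCollection): Unit = ??? // push weights in
  override def outputSchema(): StructType = ??? // schema of forward's output rows
}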

Data would be loaded in a unified way from Spark DataFrames. We are working on this in the javacpp+dataframes branch; see, for example, this file.
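
For example, an epoch of training could then be driven from a DataFrame roughly like this (sketch only; trainEpoch and makeSolver are placeholders, not existing SparkNet code):

import org.apache.spark.sql.DataFrame

// Run one solver step per partition of the DataFrame on the executors.
// makeSolver stands in for whatever constructs the backend solver on each
// executor; it must be serializable (or resolved executor-side).
def trainEpoch(df: DataFrame, makeSolver: () => Solver): Unit = {
  df.rdd.foreachPartition { rowIt =>
    val solver = makeSolver()
    solver.step(rowIt)
  }
}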

@javadba
Contributor Author

javadba commented Feb 17, 2016

Awesome! This may be a fair bit less complicated than we anticipated. I am interested in trying out that trait with another ML library.

Please suggest which test(s) to run that would best validate the usability of your NetInterface with the OtherMlLibrary framework.

@pcmoritz
Collaborator

Thanks; the least complicated approaches are often the best.

I can sketch how we plan to implement the interface for TensorFlow.

Assume you have a TensorFlow graph definition like this (in Python):

import tensorflow as tf
sess = tf.InteractiveSession()
# Placeholder for a batch of flattened 28x28 images (e.g. MNIST).
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
# Softmax classifier over 10 classes.
y = tf.nn.softmax(tf.matmul(x, W) + b)
tf.initialize_all_variables().run()

You can then serialize the graph and the weights in the following way:

saver = tf.train.Saver()
saver.save(sess, "model.bin")  # write the variable values to disk
g = tf.get_default_graph()     # the graph the session has been building
graph_def = g.as_graph_def()   # serializable protocol buffer of the graph

SparkNet would provide a TensorFlowNet class that implements the Net trait and whose constructor takes the protocol buffer definition generated by g.as_graph_def(). Furthermore, there would be a procedure for loading the weights saved by tf.train.Saver into a WeightCollection object, and an implementation of setWeights that loads those weights into the network.
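
In outline, such a class might look like the following (a sketch under the assumptions above; the constructor argument and all method bodies are guesses, not the actual SparkNet implementation):

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType

// Sketch of a TensorFlow-backed Net built from a serialized GraphDef.
class TensorFlowNet(graphDefBytes: Array[Byte]) extends NetInterface {
  // A TF session would be created here from graphDefBytes via the JavaCPP preset.
  override def forward(rowIt: Iterator[Row]): Array[Row] = ???
  override def forwardBackward(rowIt: Iterator[Row]): Unit = ???
  // Read the TF variables (saved by tf.train.Saver) into a WeightCollection.
  override def getWeights(): WeightCollection = ???
  // Assign the TF variables from the given WeightCollection.
  override def setWeights(weights: WeightCollection): Unit = ???
  override def outputSchema(): StructType = ???
}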

If you are interested in pursuing this, you can start from the JavaCPP TensorFlow implementation and implement the TensorFlowNet as well as the TensorFlowSolver class. This is a high priority for us, but there are a few other things we would like to improve first.
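
A TensorFlowSolver could then be a thin wrapper over the net (again only a sketch; the update step is an assumption):

import org.apache.spark.sql.Row

// Sketch: one solver step = forward/backward plus an optimizer update.
class TensorFlowSolver(net: TensorFlowNet) extends Solver {
  override def step(rowIt: Iterator[Row]): Unit = {
    net.forwardBackward(rowIt) // compute gradients for this partition
    // then run the graph's training op to apply the optimizer update
  }
}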

@javadba
Contributor Author

javadba commented Feb 17, 2016

OK, I will first dig a bit into the JavaCPP presets as background and then look at the JavaCPP preset for TensorFlow. ETA for an update is late Thursday 2/18.


@pcmoritz
Collaborator

Great, any progress on this will be very helpful for the project, and don't hesitate to ask questions if you run into problems. We have a bunch of experience with JavaCPP by now and might be able to help you.

To get started, you can try running both ExampleTrainer.java from the TensorFlow preset and our Cifar training app in the SparkNet javacpp+dataframes branch. The branch is almost ready to merge; we just haven't gotten around to creating the AMI yet.
